In this assignment I explore building my own models from scratch as well as several established architectures, implemented for my CA1 Deep Learning assignment. These architectures include ResNet, DenseNet, ResNeXt and CoAtNet.
(Part A: ResNet, DenseNet) | (Part B: ResNeXt, CoAtNet)
The CIFAR100 dataset, also known as the Canadian Institute For Advanced Research 100 dataset, is a subset of the Tiny Images dataset and consists of 60000 32x32 color images. The 100 classes in CIFAR-100 are grouped into 20 superclasses, with 5 subclasses per superclass.
The goal is to build, analyze and evaluate multiple models using the CIFAR100 dataset and eventually settle on a model that is able to generalize well to new data without overfitting.
All code is original unless attributed
!nvidia-smi -L
GPU 0: NVIDIA GeForce RTX 3080 Ti (UUID: GPU-0e42f8c3-6223-f1d6-5ca0-57d93c5a3520) GPU 1: NVIDIA GeForce RTX 3080 Ti (UUID: GPU-ce810bb2-3e47-b722-613a-d3358dbf87ba)
!pip install matplotlib
!pip install seaborn
!pip install torch-summary
!pip install -U scikit-learn
!pip install -U matplotlib
!pip install einops
!pip install tensorflow
# !pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
import torch, re, time, gc, itertools, random, os, sys
from IPython.display import clear_output
import torch.optim as optim
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from copy import copy, deepcopy
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Type, Union, Tuple
from sklearn.metrics import classification_report, confusion_matrix
from torchsummary import summary
from einops import rearrange
from einops.layers.torch import Rearrange
# Tensorflow for the coarse labelled data only
from tensorflow.keras.datasets.cifar100 import load_data
# from sklearn.preprocessing import OneHotEncoder
sns.set(style="ticks")
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Again, this is a relatively high batch size (chosen for faster computation). Stronger data augmentation will be needed to compensate.
BATCH_SIZE = 512
Converting the raw data into torch.FloatTensor so it can be used for EDA
Class Labels from Metadata:
Format (class number range): Superclass - Subclasses:
Just by looking at the labels, with some simple domain knowledge, I can tell that this is going to be a much harder classification problem than FashionMNIST. There are many more classes, and some classes are quite similar to one another; in fact, subclasses of the same superclass share similar features, so our model needs to be robust enough to extract the minor differences between them.
training_data = datasets.CIFAR100(
root="data",
train=True,
download=True,
transform=transforms.ToTensor()
)
test_data = datasets.CIFAR100(
root="data",
train=False,
download=True,
transform=transforms.ToTensor()
)
class_labels = training_data.classes
train_loader = DataLoader(training_data, batch_size=len(training_data))
test_loader = DataLoader(test_data, batch_size=len(test_data))
train_data = torch.Tensor(next(iter(train_loader))[0].numpy())
test_data = torch.Tensor(next(iter(test_loader))[0].numpy())
train_label = torch.Tensor(next(iter(train_loader))[1].numpy())
test_label = torch.Tensor(next(iter(test_loader))[1].numpy())
del train_loader, test_loader, training_data
Files already downloaded and verified Files already downloaded and verified
(_, train_label_coarse), (_, test_label_coarse) = load_data('coarse')
train_label_coarse = train_label_coarse.reshape(len(train_label_coarse))
test_label_coarse = test_label_coarse.reshape(len(test_label_coarse))
# Referenced from Kaggle link [Section 1.1] (official metadata from the University of Toronto)
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')  # renamed from `dict` to avoid shadowing the builtin
    return data
metadata_path = './meta'
metadata = unpickle(metadata_path)
superclass_labels = list(metadata[b'coarse_label_names'])
Function to plot distribution of dataset with their classes
def plotDist(y_data, class_labels):
    class_count = ''
    labels, counts = np.unique(y_data, return_counts=True)
    c = 0
    for label, count in zip(labels, counts):
        c += 1
        class_count += f"{class_labels[int(label)]}: {count}, "
        if c % 5 == 0:
            class_count += '\n'
    print(class_count)
    fig, ax = plt.subplots()
    fig.set_size_inches(15, 4)
    g = sns.barplot(x=class_labels, y=counts, ax=ax)
    ax.tick_params(axis='x', rotation=90)
    g.tick_params(labelsize=7)
    plt.grid()
    plt.tight_layout()
Fine labels (subclasses):
plotDist(train_label,class_labels)
apple: 500, aquarium_fish: 500, baby: 500, bear: 500, beaver: 500, bed: 500, bee: 500, beetle: 500, bicycle: 500, bottle: 500, bowl: 500, boy: 500, bridge: 500, bus: 500, butterfly: 500, camel: 500, can: 500, castle: 500, caterpillar: 500, cattle: 500, chair: 500, chimpanzee: 500, clock: 500, cloud: 500, cockroach: 500, couch: 500, crab: 500, crocodile: 500, cup: 500, dinosaur: 500, dolphin: 500, elephant: 500, flatfish: 500, forest: 500, fox: 500, girl: 500, hamster: 500, house: 500, kangaroo: 500, keyboard: 500, lamp: 500, lawn_mower: 500, leopard: 500, lion: 500, lizard: 500, lobster: 500, man: 500, maple_tree: 500, motorcycle: 500, mountain: 500, mouse: 500, mushroom: 500, oak_tree: 500, orange: 500, orchid: 500, otter: 500, palm_tree: 500, pear: 500, pickup_truck: 500, pine_tree: 500, plain: 500, plate: 500, poppy: 500, porcupine: 500, possum: 500, rabbit: 500, raccoon: 500, ray: 500, road: 500, rocket: 500, rose: 500, sea: 500, seal: 500, shark: 500, shrew: 500, skunk: 500, skyscraper: 500, snail: 500, snake: 500, spider: 500, squirrel: 500, streetcar: 500, sunflower: 500, sweet_pepper: 500, table: 500, tank: 500, telephone: 500, television: 500, tiger: 500, tractor: 500, train: 500, trout: 500, tulip: 500, turtle: 500, wardrobe: 500, whale: 500, willow_tree: 500, wolf: 500, woman: 500, worm: 500,
Coarse labels (superclasses):
plotDist(train_label_coarse,superclass_labels)
b'aquatic_mammals': 2500, b'fish': 2500, b'flowers': 2500, b'food_containers': 2500, b'fruit_and_vegetables': 2500, b'household_electrical_devices': 2500, b'household_furniture': 2500, b'insects': 2500, b'large_carnivores': 2500, b'large_man-made_outdoor_things': 2500, b'large_natural_outdoor_scenes': 2500, b'large_omnivores_and_herbivores': 2500, b'medium_mammals': 2500, b'non-insect_invertebrates': 2500, b'people': 2500, b'reptiles': 2500, b'small_mammals': 2500, b'trees': 2500, b'vehicles_1': 2500, b'vehicles_2': 2500,
Off the bat, I can tell that there is the same number of datapoints for each class: the distribution is uniform. With only 500 samples per subclass and 2500 samples per superclass, it might be hard to build a model that generalizes well to unseen data, especially for the subclasses (fine labels) and after the train-validation split. Data augmentation becomes even more important here: since our data is very limited we need to replicate our dataset, and perhaps use a much lower batch size than in Part A (FashionMNIST classification).
train_data.numpy().shape
(50000, 3, 32, 32)
The data is not in the right shape or datatype for imshow, so we define a function to transpose it from (N, C, H, W) to (N, H, W, C)
def img4np(tensor):
    tensor = np.swapaxes(tensor.numpy(), 1, -1)
    return np.swapaxes(tensor, 1, 2)
train_data_np = img4np(train_data)
train_data_np.shape
(50000, 32, 32, 3)
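As a side note, the same (N, C, H, W) → (N, H, W, C) conversion can be written as a single np.transpose call. A quick sanity check (on a dummy batch) that the two-swap version above and the one-call version agree:

```python
import numpy as np

def img4np_transpose(arr):
    """Convert a batch from (N, C, H, W) to (N, H, W, C) in one call."""
    return np.transpose(arr, (0, 2, 3, 1))

# Verify on a small random batch that the two-swap approach and the
# single transpose produce identical results.
batch = np.random.rand(4, 3, 32, 32).astype(np.float32)
two_swaps = np.swapaxes(np.swapaxes(batch, 1, -1), 1, 2)
assert np.array_equal(two_swaps, img4np_transpose(batch))
print(two_swaps.shape)  # (4, 32, 32, 3)
```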
Let's take a look at some images from every superclass
fig, ax = plt.subplots(20, 10, figsize=(15, 30))
for i in range(20):
    for j in range(10):
        label = superclass_labels[i]
        images = train_data_np[np.squeeze(train_label_coarse == i)]
        subplot = ax[i, j]
        subplot.axis("off")
        subplot.imshow(images[i + j])
        if j == 0:
            subplot.set_title(f"{label}: "[2:-1])  # strip the b'' bytes prefix
Let's take a look at the subclasses (fine labels) of our dataset
fig, ax = plt.subplots(20, 5, figsize=(15, 30))
for i in range(20):
    for j in range(5):
        label = class_labels[i * 5 + j]
        images = train_data_np[np.squeeze(train_label == i * 5 + j)]
        subplot = ax[i, j]
        subplot.axis("off")
        subplot.imshow(images[0])
        subplot.set_title(f"Label: {label}")
Displaying the average image for subclasses
fig, ax = plt.subplots(20, 5, figsize=(15, 30))
for i in range(20):
    for j in range(5):
        label = class_labels[i * 5 + j]
        avg_image = np.mean(train_data_np[np.squeeze(train_label == i * 5 + j)], axis=0)
        subplot = ax[i, j]
        subplot.axis("off")
        subplot.imshow(avg_image)
        subplot.set_title(f"Label: {label}")
Displaying the average image for superclasses
fig, ax = plt.subplots(4, 5, figsize=(25, 15))
for i in range(4):
    for j in range(5):
        # index must be i * 5 + j (5 columns per row) to cover all 20 superclasses
        label = superclass_labels[i * 5 + j]
        avg_image = np.mean(train_data_np[np.squeeze(train_label_coarse == i * 5 + j)], axis=0)
        subplot = ax[i, j]
        subplot.axis("off")
        subplot.imshow(avg_image)
        subplot.set_title(f"Label: {label}")
Displaying the RGB average for all 20 superclasses or 100 subclasses is too much and rather incomprehensible. Therefore, I will just display the average over the whole dataset (all CIFAR100 images) and see whether there is any insightful analysis to get out of it.
From all the images shown above, this is definitely going to be a much harder dataset to deal with than FashionMNIST. It has more classes, and the images within a class differ quite a lot from one another; in fact, from the average images we could barely make out some of the classes, especially the fine ones. Furthermore, the images between classes do not differ much either: at least 5 classes are categorised under each superclass (coarse label). However, the dataset also provides much more detail than the greyscale FashionMNIST: CIFAR100 images are 32x32 with 3 channels, so we can make out far more detail. All in all, although CIFAR100 images provide more information, they also pose a much more complex problem than FashionMNIST, so more complex models and stronger data augmentation are expected.
plt.imshow(np.mean(train_data_np[:,:,:,0], axis=0),cmap=mpl.colormaps['Reds'])
plt.title('Red Average')
plt.show()
plt.imshow(np.mean(train_data_np[:,:,:,1], axis=0),cmap=mpl.colormaps['Greens'])
plt.title('Green Average')
plt.show()
plt.imshow(np.mean(train_data_np[:,:,:,2], axis=0),cmap=mpl.colormaps['Blues'])
plt.title('Blue Average')
plt.show()
plt.imshow(np.mean(train_data_np, axis=0))
plt.title('RGB Average')
plt.show()
The overall RGB average looks almost exactly like a blank grey image, with perhaps something slightly more red-orange in the middle, as seen by the higher red values in the centre of the Red-average image. A grey image is not surprising: with so many different classes in CIFAR100, the overall average image is bound to be an incomprehensible mess. This also suggests that a single general autoencoder might struggle to identify true outliers, especially with 100 classes to compress and generalize over. One way to overcome this would be to train multiple autoencoders on separate classes. However, to save time I will still use one general but robust autoencoder to identify outliers.
It would be quite impractical to view the outliers of all 100 classes, so I will just view the top 25 outliers of the dataset and observe their features and classes.
# This autoencoder is original
class SimpleAutoEncoder(nn.Module):
    def __init__(self):
        super(SimpleAutoEncoder, self).__init__()
        # encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 128, 5, 1, padding=1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.Conv2d(128, 64, 3), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 32, 3), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 16, 3), nn.BatchNorm2d(16), nn.ReLU(True)
        )
        # bottleneck
        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(9216, 384), nn.BatchNorm1d(384), nn.ReLU(True),
            nn.Linear(384, 9216), nn.BatchNorm1d(9216), nn.ReLU(True),
            nn.Unflatten(1, (16, 24, 24))  # 9216 = 16 * 24 * 24 (calculated manually)
        )
        # decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 32, 3), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.ConvTranspose2d(32, 64, 3), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 128, 3), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 5, 1, padding=1), nn.ReLU(True),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x)
        return x
model = nn.DataParallel(SimpleAutoEncoder())
model = model.to(device)
print(model)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr= 0.0002)
def train_ae(net, trainloader, NUM_EPOCHS, displayEvery=5):
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        for data in trainloader:
            img, _ = data  # no need for the labels
            img = img.to(device)
            optimizer.zero_grad()
            outputs = net(img)
            loss = criterion(outputs, img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(trainloader)
        if (epoch + 1) % displayEvery == 0:
            print('- Epoch {} of {}, Train Loss: {:.3f}'.format(epoch + 1, NUM_EPOCHS, epoch_loss))
def get_outliers(net, trainloader):
    net.eval()
    running_loss = np.array([])
    with torch.no_grad():  # no gradients needed when scoring reconstructions
        for data in trainloader:
            img, _ = data  # no need for the labels
            img = img.to(device)
            outputs = net(img)
            loss = criterion(outputs, img)
            running_loss = np.append(running_loss, np.array(loss.item()))
    highest25 = np.argsort(running_loss)[::-1][:25]  # top 25 highest reconstruction loss
    return highest25
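Scoring the whole training set one image at a time is slow. An alternative sketch (a hypothetical helper, not used in this notebook) computes the same per-image reconstruction error batch-wise by averaging the squared error over each image's own channel and spatial dimensions:

```python
import numpy as np
import torch

def get_outliers_batched(net, loader, device, top_k=25):
    """Per-image MSE reconstruction error, computed batch-wise."""
    net.eval()
    errors = []
    with torch.no_grad():  # no gradients needed for scoring
        for img, _ in loader:
            img = img.to(device)
            out = net(img)
            # mean over C, H, W -> one scalar error per image
            per_image = ((out - img) ** 2).mean(dim=(1, 2, 3))
            errors.append(per_image.cpu())
    errors = torch.cat(errors).numpy()
    return np.argsort(errors)[::-1][:top_k]  # indices of the worst reconstructions
```

This works with any batch size, so the single-image loader would not be needed.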
DataParallel(
(module): SimpleAutoEncoder(
(encoder): Sequential(
(0): Conv2d(3, 128, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
(7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
(10): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
)
(bottleneck): Sequential(
(0): Flatten(start_dim=1, end_dim=-1)
(1): Linear(in_features=9216, out_features=384, bias=True)
(2): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): Linear(in_features=384, out_features=9216, bias=True)
(5): BatchNorm1d(9216, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): Unflatten(dim=1, unflattened_size=(16, 24, 24))
)
(decoder): Sequential(
(0): ConvTranspose2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 3, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
(10): ReLU(inplace=True)
)
)
)
# Since I am making an autoencoder for outlier analysis of our training data, higher batch
# size is generally better (as I want my autoencoder to overfit on my training data, and also it trains faster)
train_dataloader_autoencode = DataLoader(TensorDataset(train_data.type('torch.FloatTensor'),train_label.type('torch.FloatTensor')),batch_size=2048,shuffle=True)
# Single batch loader to compare RMSE of each image in each class
def single_loader():
    return DataLoader(TensorDataset(train_data.type('torch.FloatTensor'), train_label.type('torch.FloatTensor')), batch_size=1, shuffle=False)
train_ae(model, train_dataloader_autoencode, 65)
# Remove the dataloader from memory
del train_dataloader_autoencode
- Epoch 5 of 65, Train Loss: 0.053 - Epoch 10 of 65, Train Loss: 0.033 - Epoch 15 of 65, Train Loss: 0.026 - Epoch 20 of 65, Train Loss: 0.023 - Epoch 25 of 65, Train Loss: 0.020 - Epoch 30 of 65, Train Loss: 0.019 - Epoch 35 of 65, Train Loss: 0.018 - Epoch 40 of 65, Train Loss: 0.017 - Epoch 45 of 65, Train Loss: 0.016 - Epoch 50 of 65, Train Loss: 0.015 - Epoch 55 of 65, Train Loss: 0.014 - Epoch 60 of 65, Train Loss: 0.013 - Epoch 65 of 65, Train Loss: 0.014
outliers_idx = get_outliers(model, single_loader()).astype(int)
fig, ax = plt.subplots(5, 5, figsize=(25, 15))
for i in range(5):
    for j in range(5):
        label_idx = train_label.type(torch.LongTensor).numpy()[outliers_idx[i * 5 + j]]
        image = train_data_np[outliers_idx[i * 5 + j]]
        subplot = ax[i, j]
        subplot.axis("off")
        subplot.imshow(image)
        subplot.set_title(f"Label: {class_labels[label_idx]}")
Here we hit the limitation of using just one autoencoder for outlier analysis (explained earlier). We can tell that aquatic animals (by superclass) are the most common outliers, as they look very different from the general image of our dataset. However, I would not say these aquatic animals are true outliers even though they have the highest RMSE (difference in features compared to the average image): the obvious reason they surface is that they are predominantly blue, while the average image is not. To overcome this it would make more sense to train an individual autoencoder on every superclass and then observe the outliers, since a blue image is obviously common, not anomalous, within the aquatic-animal classes themselves. Still, using one general autoencoder tells us that blueish images are generally the most different and stand out the most from the average image. In fact, it is a surprise to see a keyboard and an orchid show up here; they could easily be misidentified as aquatic super- or subclasses. This also gives me insight into what type of data augmentation to use.
Making an autoencoder for every superclass is too time-consuming, so it is not done here.
| Split | Size |
|---|---|
| Training | 40K |
| Validation | 10K |
| Testing | 10k |
Only the training set should be augmented. The only transformation I can apply to the validation/test sets is normalization (either min-max or z-score), and the statistics used for normalization must come from the training set only. This minimizes data leakage, since we should not know any values from the testing data (the difference should be negligible, but it is good practice).
The testing data cannot be used for evaluation at any point until the final evaluation. This removes any decision-making based on the testing data, and hence any overfitting to it.
No information from testing data can be used for training or validation data, and no information from validation data should be used for training data.
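For instance, if z-score normalization were used, it might be sketched as follows (hypothetical names; the split sizes here are dummies). The key point is that the mean and std come from the training split only and are reused as-is on the other splits:

```python
import torch

# Hypothetical training split: (N, 3, 32, 32) floats in [0, 1]
train_imgs = torch.rand(1000, 3, 32, 32)

# Per-channel statistics computed from the TRAINING split only
mean = train_imgs.mean(dim=(0, 2, 3), keepdim=True)  # shape (1, 3, 1, 1)
std = train_imgs.std(dim=(0, 2, 3), keepdim=True)

def z_normalize(x):
    """Apply the training-set statistics to any split (train/val/test)."""
    return (x - mean) / std

# Reusing the SAME train stats on validation/test data avoids leakage
val_imgs = z_normalize(torch.rand(100, 3, 32, 32))
```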
print('Fine data min:',train_data.min())
print('Fine data max:',train_data.max())
print(train_data.shape,'\n\n')
Fine data min: tensor(0.) Fine data max: tensor(1.) torch.Size([50000, 3, 32, 32])
min-max scaling:
${x'} = \frac{x - \min(x)}{\max(x) - \min(x)}$
which, with $\min(x) = 0$ and $\max(x) = 1$, simplifies to:
${x'} = x$
...so the dataset is already in the $[0, 1]$ range and no further normalization is needed.
Swap some axes, and change data type to torch tensor
def npToTensor(np_):
    np_ = np.swapaxes(np_, 1, -1)
    np_ = np.swapaxes(np_, 2, -1)
    return torch.Tensor(np_)
More data preparation here:
traindata_input = train_data[:40000]
valdata_input = train_data[40000:]
# One hot encoded labels
traindata_label = F.one_hot(train_label.type('torch.LongTensor')[:40000],100)
valdata_label = F.one_hot(train_label.type('torch.LongTensor')[40000:],100)
train_label = F.one_hot(train_label.type('torch.LongTensor'),100)
# del train_label
test_label = F.one_hot(test_label.type('torch.LongTensor'),100)
display(traindata_input.shape)
display(valdata_input.shape)
# One hot encoded labels
train_label_coarse = F.one_hot(torch.Tensor(train_label_coarse).type('torch.LongTensor'),20)
trainlabel_coarse = train_label_coarse[:40000]
vallabel_coarse = train_label_coarse[40000:]
test_label_coarse = F.one_hot(torch.Tensor(test_label_coarse).type('torch.LongTensor'),20)
torch.Size([40000, 3, 32, 32])
torch.Size([10000, 3, 32, 32])
CutMix source code from: Github - idoonet, Original Paper: [Yun, 2019]
The code is referenced from the given link; however, I added a function to augment images in different batch sizes.
def rand_bbox(size, lamb):
    """Generate a random bounding box.
    Args:
        size: [width, height] of the image
        lamb: (lambda) cut ratio parameter, sampled from a Beta distribution
    Returns:
        Bounding box corner coordinates
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lamb)
    cut_w = np.int32(W * cut_rat)
    cut_h = np.int32(H * cut_rat)
    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def generate_cutmix_image(image_batch, image_batch_labels, beta=1):
    """Generate a CutMix augmented image from a batch.
    Args:
        image_batch: a batch of input images
        image_batch_labels: labels corresponding to the image batch
        beta: a parameter of the Beta distribution
    Returns:
        CutMix image batch, updated labels
    """
    # generate mixed sample
    lam = np.random.beta(beta, beta)
    rand_index = np.random.permutation(len(image_batch))
    target_a = image_batch_labels
    target_b = image_batch_labels[rand_index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
    image_batch_updated = image_batch.copy()
    image_batch_updated[:, bbx1:bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image_batch.shape[1] * image_batch.shape[2]))
    label = target_a * lam + target_b * (1. - lam)
    return image_batch_updated, label
def CutMixBatches(data, label, batch_size=1, classNum=100, beta=1):
    stack = np.empty([0, 32, 32, 3])
    stack_label = np.empty([0, classNum])
    for i in range(int(len(data) / batch_size)):
        start_slice = i * batch_size
        end_slice = (i + 1) * batch_size
        data_batch, label_batch = generate_cutmix_image(data[start_slice:end_slice], label[start_slice:end_slice], beta=beta)
        stack = np.concatenate((stack, data_batch))
        stack_label = np.concatenate((stack_label, label_batch))
    stack = npToTensor(stack)
    stack_label = torch.Tensor(stack_label)
    return stack, stack_label
def DataAugment(train_set, composeAug, batch_size=1):
    train_set = train_set * 255
    train_set = train_set.type(torch.ByteTensor)
    stack = torch.tensor([], dtype=torch.float)
    for i in range(int(len(train_set) / batch_size)):
        stack = torch.cat((stack, composeAug(train_set[i * batch_size:(i + 1) * batch_size])))
    stack = stack / 255
    return stack
# Augmentation Composition from scratch
imageColorTrans = transforms.Compose([transforms.RandomHorizontalFlip(0.825),
transforms.RandomRotation((-10,10)),
transforms.ColorJitter(brightness=0.5, contrast=0.4, saturation=0.1, hue=0.1),
transforms.RandomPerspective(0.0825,0.25),
transforms.RandomSolarize(0.985, p=0.15),
])
imageColorTrans2 = transforms.Compose([transforms.RandomRotation((-10,10)),
transforms.ColorJitter(brightness=0.5, contrast=0.4, saturation=0.1, hue=0.1),
transforms.RandomPerspective(0.0825,0.25),
transforms.RandomSolarize(0.985, p=0.15),
])
# Augmentation to compare to my scratch CIFAR10 policy chosen
imageAutoAugment = transforms.Compose([transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
transforms.RandomHorizontalFlip(0.825)])
imageAutoAugment2 = transforms.Compose([transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10)])
imageAutoAugmentImageNet = transforms.Compose([transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
transforms.RandomHorizontalFlip(0.825)])
imageAutoAugmentImageNet2 = transforms.Compose([transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET)])
imageAugMix = transforms.Compose([transforms.AugMix(),
transforms.RandomHorizontalFlip(0.825)])
imageAugMix2 = transforms.Compose([transforms.AugMix()])
flipImage = transforms.Compose([transforms.RandomHorizontalFlip(1)])
Pre-replicate my data
traindata_input_flipped = DataAugment(traindata_input,flipImage,batch_size=20000)
traindata_input_replicated = torch.cat((traindata_input,traindata_input_flipped))
traindata_label_replicated = torch.cat((traindata_label,traindata_label))
def displayAugment(augmentCompose=imageColorTrans):
    train_data_test = DataAugment(traindata_input, augmentCompose, batch_size=1000)
    train_data_test_np = train_data_test.numpy()
    train_data_test_np = np.swapaxes(train_data_test_np, 1, -1)
    train_data_test_np = np.swapaxes(train_data_test_np, 1, 2)
    fig, ax = plt.subplots(10, 5, figsize=(15, 30))
    for i in range(10):
        for j in range(5):
            label = class_labels[i * 5 + j]
            images = train_data_test_np[np.squeeze(traindata_label.argmax(-1) == i * 5 + j)]
            subplot = ax[i, j]
            subplot.axis("off")
            subplot.imshow(images[random.randint(0, 350)])
            subplot.set_title(f"Label: {label}")
displayAugment()
ImageNet Augmentation Policy
displayAugment(imageAutoAugmentImageNet2)
CIFAR10 Augmentation Policy
displayAugment(imageAutoAugment2)
AugMix Augmentation
displayAugment(imageAugMix2)
Let's visualize our CutMix data next
image_batch = img4np(traindata_input)
image_batch_labels = traindata_label.numpy()
image_batch_updated, image_batch_labels_updated = CutMixBatches(image_batch,image_batch_labels,batch_size=2000,beta= 1)
image_batch_updated = img4np(image_batch_updated)
image_batch_labels_updated = image_batch_labels_updated.numpy()
randomIdx = random.sample(range(40000), 4)
# Show original images
print("Original Images")
for i in range(2):
    for j in range(2):
        plt.subplot(2, 2, 2 * i + j + 1)
        plt.imshow(image_batch[randomIdx[2 * i + j]])
        plt.title(f'Label: {class_labels[image_batch_labels[randomIdx[2*i+j]].argmax(-1)]}')
plt.tight_layout()
plt.show()
# Show CutMix images
print("CutMix Images")
for i in range(2):
    for j in range(2):
        plt.subplot(2, 2, 2 * i + j + 1)
        plt.imshow(image_batch_updated[randomIdx[2 * i + j]])
        highestIdx = image_batch_labels_updated[randomIdx[2 * i + j]].argmax(-1)
        secondhighestIdx = image_batch_labels_updated[randomIdx[2 * i + j]].argsort()[-2]
        plt.title(f'{class_labels[highestIdx]}({image_batch_labels_updated[randomIdx[2*i+j]][highestIdx]:.2f})/{class_labels[secondhighestIdx]}({image_batch_labels_updated[randomIdx[2*i+j]][secondhighestIdx]:.2f})')
plt.tight_layout()
plt.show()
# Print labels
print('Original labels:')
print(image_batch_labels)
print('Updated labels')
print(image_batch_labels_updated)
del image_batch_updated, image_batch, image_batch_labels, image_batch_labels_updated
Original Images
CutMix Images
Original labels: [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [1 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] Updated labels [[0. 0. 0. ... 0. 0. 0. ] [0. 0. 0. ... 0. 0. 0. ] [0.68359375 0. 0. ... 0. 0. 0. ] ... [0. 0. 0. ... 0. 0. 0. ] [0. 0. 0. ... 0. 0. 0. ] [0. 0. 0. ... 0. 0. 0. ]]
As we can observe, CutMix is working: two labels remain after the augmentation and their label probabilities are updated accordingly. The settings chosen are the defaults from the original CutMix paper. Related mixing augmentations such as MixUp & FMix have also been developed and would be good to explore.
Original Paper: [Yun, 2019]
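Of the mixing variants mentioned, MixUp is the simplest to sketch: instead of pasting a rectangular patch, it blends whole images and their one-hot labels with the same λ drawn from Beta(α, α). A minimal illustrative version (a hypothetical helper, not part of this notebook's pipeline), following [Zhang, 2018] rather than the CutMix code above:

```python
import numpy as np

def mixup_batch(images, labels, alpha=1.0):
    """Blend each image/label with a random partner: x' = lam*x_a + (1-lam)*x_b."""
    lam = np.random.beta(alpha, alpha)
    idx = np.random.permutation(len(images))
    mixed_images = lam * images + (1.0 - lam) * images[idx]
    mixed_labels = lam * labels + (1.0 - lam) * labels[idx]
    return mixed_images, mixed_labels
```

Each mixed label row still sums to 1, so the target remains a valid probability distribution for cross-entropy.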
# Also replicates the training data and shuffles for us
def replicateData1(traindata=traindata_input, trainlabel=traindata_label, augCompose=imageColorTrans, augBatch=1250, beta=1, classNum_=100):
    traindata2 = DataAugment(traindata, augCompose, batch_size=augBatch)
    # use the `traindata` parameter here (not the global traindata_input)
    traindata3, adjustedLabel3 = CutMixBatches(img4np(traindata), trainlabel.numpy(), batch_size=augBatch, beta=beta, classNum=classNum_)
    inputComb = torch.cat((traindata, traindata2, traindata3))
    labelComb = torch.cat((trainlabel, trainlabel, adjustedLabel3))
    # shuffle dataset
    indices = np.arange(inputComb.shape[0])
    np.random.shuffle(indices)
    inputComb = inputComb[indices]
    labelComb = labelComb[indices]
    return inputComb, labelComb
# Also replicates the training data and shuffles for us
def replicateData2(traindata=traindata_input, trainlabel=traindata_label, augCompose=imageAutoAugment, augBatch=2500):
    traindata2 = DataAugment(traindata, augCompose, batch_size=augBatch)
    inputComb = torch.cat((traindata, traindata2))
    labelComb = torch.cat((trainlabel, trainlabel))
    # shuffle dataset
    indices = np.arange(inputComb.shape[0])
    np.random.shuffle(indices)
    inputComb = inputComb[indices]
    labelComb = labelComb[indices]
    return inputComb, labelComb

# Also replicates the training data and shuffles for us
def replicateData3(traindata=traindata_input, trainlabel=traindata_label, augCompose1=imageColorTrans, augCompose2=imageAutoAugment, augBatch=2500):
    traindata2 = DataAugment(traindata, augCompose1, batch_size=augBatch)
    traindata3 = DataAugment(traindata, augCompose2, batch_size=augBatch)
    inputComb = torch.cat((traindata, traindata2, traindata3))
    labelComb = torch.cat((trainlabel, trainlabel, trainlabel))
    # shuffle dataset
    indices = np.arange(inputComb.shape[0])
    np.random.shuffle(indices)
    inputComb = inputComb[indices]
    labelComb = labelComb[indices]
    return inputComb, labelComb

def replicateData4(traindata=traindata_input, trainlabel=traindata_label, augCompose1=imageColorTrans, augCompose2=imageAutoAugment, augBatch=2500):
    traindata2 = DataAugment(traindata, augCompose1, batch_size=augBatch)
    traindata3 = DataAugment(traindata, augCompose2, batch_size=augBatch)
    inputComb = torch.cat((traindata2, traindata3))
    labelComb = torch.cat((trainlabel, trainlabel))
    # shuffle dataset
    indices = np.arange(inputComb.shape[0])
    np.random.shuffle(indices)
    inputComb = inputComb[indices]
    labelComb = labelComb[indices]
    return inputComb, labelComb

def replicateData5(trainlabel=traindata_label, augCompose1=imageColorTrans, augBatch=2500):
    inputComb = DataAugment(traindata_input_replicated, augCompose1, batch_size=augBatch)
    labelComb = torch.cat((trainlabel, trainlabel))
    # shuffle dataset
    indices = np.arange(inputComb.shape[0])
    np.random.shuffle(indices)
    inputComb = inputComb[indices]
    labelComb = labelComb[indices]
    return inputComb, labelComb
test_1, test_2 = replicateData1(augBatch=5000)
test_1 = img4np(test_1)
test_2 = test_2.numpy()
fig, ax = plt.subplots(20, 5, figsize=(15, 30))
for i in range(20):
    for j in range(5):
        label = class_labels[i * 5 + j]
        images = test_1[np.squeeze(test_2.argmax(-1) == i * 5 + j)]
        subplot = ax[i, j]
        subplot.axis("off")
        subplot.imshow(images[random.randint(0, 350)])
        subplot.set_title(f"Label: {label}")
del test_1, test_2
Collective dataset containing all augmentations. The same augmentations will be applied to the coarse dataset, just with different labels.
I will also try something different from Part A: I will resize my images to the model's original default input size instead of editing the model's input size to 32x32, even though it takes a lot more computational power to resize and train on 224x224 images.
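That upscaling step can be sketched with bilinear interpolation via F.interpolate (transforms.Resize would work too). Going from 32x32 to 224x224 multiplies the pixel count per image by 49, which is where the extra compute comes from:

```python
import torch
import torch.nn.functional as F

batch = torch.rand(4, 3, 32, 32)  # dummy CIFAR-sized batch
upscaled = F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)
print(upscaled.shape)  # torch.Size([4, 3, 224, 224])
```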
def plot_history(train_loss, train_acc, val_loss, val_acc, epoch_his, title):
    fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(16, 6))
    sns.lineplot(x=epoch_his, y=train_loss, ax=ax1, label='Train Loss')
    sns.lineplot(x=epoch_his, y=val_loss, ax=ax1, label='Validation Loss')
    sns.lineplot(x=epoch_his, y=train_acc, ax=ax2, label='Train Accuracy')
    sns.lineplot(x=epoch_his, y=val_acc, ax=ax2, label='Validation Accuracy')
    ax1.title.set_text(f'Lowest val loss @{min(val_loss):.4f} @Epoch {epoch_his[val_loss.index(min(val_loss))]}')
    ax2.title.set_text(f'Highest val acc @{max(val_acc):.4f} @Epoch {epoch_his[val_acc.index(max(val_acc))]}')
    ax1.axvline(epoch_his[val_loss.index(min(val_loss))], color='r', ls='--', alpha=0.75)
    ax2.axvline(epoch_his[val_acc.index(max(val_acc))], color='r', ls='--', alpha=0.75)
    ax1.axhline(min(val_loss), color='r', ls='--', alpha=0.75)
    ax2.axhline(max(val_acc), color='r', ls='--', alpha=0.75)
    fig.suptitle(title)
    plt.show()
@torch.no_grad() # inference only - no gradient tracking needed for reporting
def classification_rep(model,criterion,loader,train_loss, train_acc, val_loss, val_acc, epoch_his ,title, plot=True):
print(title)
if plot:
plot_history(train_loss, train_acc, val_loss, val_acc, epoch_his,title)
y_pred = []
y_true = []
model.eval()
for i, data in enumerate(loader):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
y_pred.extend(outputs.argmax(-1).tolist())
y_true.extend(labels.argmax(-1).tolist())
model.train()
print(classification_report(y_true, y_pred,target_names=class_labels))
return
def keep_history(df,train_acc,train_loss,val_loss, val_acc, epoch_his,model_desc='NIL',desc='NIL'):
cols = ['Lowest Val Loss','Highest Val Acc','Lowest Train Loss','Highest Train Acc','Epoch (Highest Val Acc)', 'Model Description', 'Parameter']
row = [min(val_loss), max(val_acc), min(train_loss), max(train_acc), epoch_his[val_acc.index(max(val_acc))], model_desc, desc]
return pd.concat([df, pd.DataFrame([row], columns=cols)])
hist_df = pd.DataFrame([],columns=['Lowest Val Loss','Highest Val Acc','Lowest Train Loss','Highest Train Acc','Epoch (Highest Val Acc)', 'Model Description', 'Parameter'])
class EarlyStopping():
def __init__(self, patience=5):
self.patience = patience
self.counter = 0
self.early_stop = False
self.lowest_val_loss = float('inf')
self.highest_val_accuracy = 0
def __call__(self, val_acc, val_loss):
if val_loss >= self.lowest_val_loss and val_acc <= self.highest_val_accuracy:
self.counter +=1
else:
self.counter = 0
if val_acc > self.highest_val_accuracy:
self.highest_val_accuracy = val_acc
if val_loss < self.lowest_val_loss:
self.lowest_val_loss = val_loss
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
def reset(self):
self.counter = 0
self.lowest_val_loss = float('inf')
self.highest_val_accuracy = 0
self.early_stop = False
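To illustrate how the patience rule above behaves, here is a standalone simulation of the same counter logic (the patience value and the toy loss/accuracy history are made up for this example):

```python
# Simulates EarlyStopping's rule: the counter only grows on epochs where
# neither validation loss nor validation accuracy beats its best so far.
patience = 3
counter, best_loss, best_acc = 0, float('inf'), 0.0
history = [(2.0, 30.0), (1.8, 32.0), (1.9, 31.0), (1.9, 31.5), (1.9, 30.0), (2.0, 29.0)]

stopped_at = None
for epoch, (val_loss, val_acc) in enumerate(history, start=1):
    if val_loss >= best_loss and val_acc <= best_acc:
        counter += 1   # no improvement on either metric
    else:
        counter = 0    # at least one metric improved - reset
    best_loss = min(best_loss, val_loss)
    best_acc = max(best_acc, val_acc)
    if counter >= patience:
        stopped_at = epoch
        break

print(stopped_at)  # → 5 (three stagnant epochs after the epoch-2 best)
```

Note that an epoch only resets the counter if it improves on the best value seen so far, not merely on the previous epoch: epoch 4 above is better than epoch 3 but still counts as stagnant.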
def train(model, input_, label_ , optimizer, NUM_EPOCHS, criterion, valloader=None,earlystopper=None,showEpoch=True,scheduler=None,displayEvery=1,augmentEvery=None,augCompose=imageAutoAugment,augBatch=2000):
train_loss_his = []
train_accuracy_his = []
val_loss_his = []
val_accuracy_his = []
epoch_his = []
if earlystopper is not None: earlystopper.reset()
model.train()
for epoch in range(NUM_EPOCHS):
t0 = time.time()
total = 0
correct = 0
running_loss = 0.0
if augmentEvery is not None and epoch % augmentEvery == 0:
input_rep, label_rep = replicateData2(traindata=input_,augCompose=augCompose,augBatch=augBatch,trainlabel=label_)
# The augmented dataset is already shuffled
loader = DataLoader(TensorDataset(input_rep.type('torch.FloatTensor'),label_rep.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
else:
loader = DataLoader(TensorDataset(input_.type('torch.FloatTensor'),label_.type('torch.FloatTensor')), shuffle=True, batch_size=BATCH_SIZE)
for data in loader:
img, label = data
img = img.to(device)
label = label.to(device)
optimizer.zero_grad()
outputs = model(img)
loss = criterion(outputs, label)
loss.backward()
optimizer.step()
# _, predicted = torch.max(outputs.data, 1)
correct += (outputs.argmax(-1) == label.argmax(-1)).sum().item()
total += label.size(0)
running_loss += loss.item()
if scheduler is not None:
scheduler.step()
accuracy = 100 * correct / total
loss = running_loss / len(loader)
train_loss_his.append(loss)
train_accuracy_his.append(accuracy)
val_loss, val_accuracy = validate(model, criterion, valloader)
val_loss_his.append(val_loss)
val_accuracy_his.append(val_accuracy)
epoch_his.append(epoch+1)
if showEpoch and (epoch+1)%displayEvery == 0:
print(f'- [Epoch {epoch+1}/{NUM_EPOCHS}] | Train Loss: {loss:.3f}| Train Accuracy: {accuracy} | Val Loss: {val_loss:.3f} | Val Accuracy: {val_accuracy} | Est: {(time.time() - t0)*displayEvery:.2f}s')
if earlystopper is not None and earlystopper(val_accuracy,val_loss):
print(f'EarlyStopper triggered at epoch {epoch+1} \n*No improvement in validation loss or accuracy for the past {earlystopper.patience} epochs')
break
print(f'Highest Val Accuracy: {max(val_accuracy_his)} @ epoch {epoch_his[val_accuracy_his.index(max(val_accuracy_his))]} | Lowest Val Loss: {min(val_loss_his)} @ epoch {epoch_his[val_loss_his.index(min(val_loss_his))]}')
return train_loss_his, train_accuracy_his, val_loss_his, val_accuracy_his,epoch_his
@torch.no_grad() # inference only - no gradient tracking during validation
def validate(model, criterion, loader=None):
correct = 0
total = 0
running_loss = []
model.eval()
for i, data in enumerate(loader):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
# _, predicted = torch.max(outputs.data, 1)
correct += (outputs.argmax(-1) == labels.argmax(-1)).sum().item()
running_loss.append(loss.item())
total += labels.size(0)
mean_val_accuracy = (100 * correct / total)
mean_val_loss = sum(running_loss) / len(running_loss)
model.train()
return mean_val_loss, mean_val_accuracy
Layers inspired by the VGG16 architecture [VGG16 Architecture]
def simpleBlock(inp_channel, layers, maxpoolStride1 = False):
block = nn.Sequential()
for i in range(layers-1):
block.add_module(f'conv_layer_{i+1}', nn.Conv2d(inp_channel, inp_channel,kernel_size=3,stride=1,padding=1))
block.add_module(f'batch_norm_{i+1}',nn.BatchNorm2d(inp_channel))
block.add_module(f'ReLU_{i+1}',nn.ReLU(inplace=True))
if layers-1 > 0 and not maxpoolStride1:
block.add_module('max_pool2x2',nn.MaxPool2d(2, 2))
elif layers-1 > 0 and maxpoolStride1:
block.add_module('max_pool2x2',nn.MaxPool2d(2, 1))
return block
class SimpleNet(nn.Module):
def __init__(self,layer_config,inp,out,growthRate=2,fc_neurons=2048,classes=100):
super(SimpleNet, self).__init__()
# Implement the sequential module for feature extraction
self.start = nn.Sequential(nn.Conv2d(in_channels=inp, out_channels=out, kernel_size=5, stride=1, padding=1), nn.BatchNorm2d(out), nn.ReLU(inplace=True))
inp = out
out = out * growthRate
self.block1 = simpleBlock(inp,layer_config[0])
self.block2 = nn.Sequential(nn.Conv2d(in_channels=inp, out_channels=out, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out), nn.ReLU(inplace=True))
self.block2_add = simpleBlock(out, layer_config[1])
#Expand
inp = out
out = out * growthRate
self.block3 = nn.Sequential(nn.Conv2d(in_channels=inp, out_channels=out, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out), nn.ReLU(inplace=True))
self.block3_add = simpleBlock(out, layer_config[2])
#Expand
inp = out
out = out * growthRate
self.block4 = nn.Sequential(nn.Conv2d(in_channels=inp, out_channels=out, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out), nn.ReLU(inplace=True))
self.block4_add = simpleBlock(out, layer_config[3], True)
self.flatten = nn.Flatten()
self.classifier = nn.Sequential(
nn.Linear(out *2 * 2, fc_neurons), nn.BatchNorm1d(fc_neurons), nn.ReLU(inplace=True), nn.Dropout(p=0.5),
nn.Linear(fc_neurons, fc_neurons), nn.BatchNorm1d(fc_neurons), nn.ReLU(inplace=True),
nn.Linear(fc_neurons, classes)
)
# No softmax at the end: CrossEntropyLoss already applies log-softmax internally
def forward(self, x):
# Features extraction
x = self.start(x)
x = self.block1(x)
x = self.block2(x)
x = self.block2_add(x)
x = self.block3(x)
x = self.block3_add(x)
x = self.block4(x)
x = self.block4_add(x)
x = self.flatten(x)
x = self.classifier(x)
return x
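The `out * 2 * 2` input size of the classifier follows from how each block shrinks the feature map. Below is a standalone trace of a 32x32 input through the stack, a sketch using the standard convolution/pooling output-size formula (it mirrors, but is not part of, the model code):

```python
# Spatial-size trace of a 32x32 input through SimpleNet's feature extractor.
def out_size(size, kernel, stride=1, padding=0):
    # standard formula: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

s = out_size(32, kernel=5, padding=1)     # start conv (k=5, p=1)    -> 30
s = out_size(s, kernel=2, stride=2)       # block1 max pool          -> 15
s = out_size(s, kernel=2, stride=2)       # block2_add max pool      -> 7
s = out_size(s, kernel=2, stride=2)       # block3_add max pool      -> 3
s = out_size(s, kernel=2, stride=1)       # block4_add pool, stride 1 -> 2
print(s)  # → 2, so the flattened feature vector has out * 2 * 2 entries
```

The 3x3 convolutions all use stride 1 and padding 1, so they leave the spatial size unchanged and are omitted from the trace.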

[Image source: Xie, 2016]
More explanation:
Greatly reduced the original source code by:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
#3x3 convolution with padding
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
#1x1 convolution - frequently used for downsample/bottleneck
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int, #input
planes: int, #output
stride: int = 1,
downsample: Optional[nn.Module] = None, #optional conv layer (see forward to see where downsample takes place)
groups: int = 1, #cardinality - horizontal group
base_width: int = 64, #'width' per cardinality - no. of filter each cardinality
dilation: int = 1, #"inflation" rate of the kernel by inserting holes between the kernel element
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
#use custom norm_layer if exist, otherwise just use BatchNorm2d
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
#if downsample layer exist, downsample x (aka identity)
if self.downsample is not None:
identity = self.downsample(x)
#add identity to output after running the layers in Bottleneck
out += identity
out = self.relu(out)
return out
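The single `width` line in Bottleneck is what turns this ResNet into a ResNeXt: `groups` sets the cardinality and `base_width` the filters per group. A quick standalone check of that arithmetic, comparing the plain default against the `groups=32, width_per_group=4` configuration instantiated later in the notebook:

```python
# Width of Bottleneck's grouped 3x3 conv, mirroring the line
# `width = int(planes * (base_width / 64.0)) * groups` above.
def bottleneck_width(planes, base_width=64, groups=1):
    return int(planes * (base_width / 64.0)) * groups

print(bottleneck_width(64))                           # plain ResNet: 64
print(bottleneck_width(64, base_width=4, groups=32))  # ResNeXt 32x4d: 128
```

So the ResNeXt bottleneck is wider overall (128 channels) but split into 32 groups of 4 filters each, which is where the extra representational diversity comes from at similar cost.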
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 100,
zero_init_residual: bool = True,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
f"or a 3-element tuple, got {replace_stride_with_dilation}"
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck) and m.bn3.weight is not None:
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(
self,
block: Type[Union[BasicBlock, Bottleneck]],
planes: int,
blocks: int,
stride: int = 1,
dilate: bool = False,
) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(
self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
)
)
return nn.Sequential(*layers)
def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._forward_impl(x)

[Image source: Wu, 2021]
This model has achieved state-of-the-art performance in image classification when evaluated on benchmark datasets such as ImageNet.
Greatly reduced the original source code by:
#CoAtNet basic 3x3 conv 'structure'
def conv_3x3_bn(inp, oup, image_size, downsample=False):
stride = 2 if downsample else 1
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.GELU()
#Gaussian Error Linear Units(GELU) - similar to RELU, for resources see (https://pytorch.org/docs/stable/generated/torch.nn.GELU.html)
)
class PreNorm(nn.Module):
def __init__(self, dim, fn, norm):
super().__init__()
self.norm = norm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
#squeeze-excitation - global-pools the spatial dimensions, then re-weights each channel, giving a cheap form of global attention
class SE(nn.Module):
def __init__(self, inp, oup, expansion=0.25):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(oup, int(inp * expansion), bias=False),
nn.GELU(),
nn.Linear(int(inp * expansion), oup, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class MBConv(nn.Module):
def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
super().__init__()
self.downsample = downsample
stride = 2 if self.downsample else 1
hidden_dim = int(inp * expansion)
if self.downsample:
self.pool = nn.MaxPool2d(3, 2, 1)
self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
# down-sample in the first conv
nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SE(inp, hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
self.conv = PreNorm(inp, self.conv, nn.BatchNorm2d)
def forward(self, x):
if self.downsample:
return self.proj(self.pool(x)) + self.conv(x)
else:
return x + self.conv(x)
class Attention(nn.Module):
def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == inp)
self.ih, self.iw = image_size
self.heads = heads
self.scale = dim_head ** -0.5
# parameter table of relative position bias
self.relative_bias_table = nn.Parameter(
torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))
coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
coords = torch.flatten(torch.stack(coords), 1)
relative_coords = coords[:, :, None] - coords[:, None, :]
relative_coords[0] += self.ih - 1
relative_coords[1] += self.iw - 1
relative_coords[0] *= 2 * self.iw - 1
relative_coords = rearrange(relative_coords, 'c h w -> h w c')
relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
self.register_buffer("relative_index", relative_index)
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, oup),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b n (h d) -> b h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# Use "gather" for more efficiency on GPUs
relative_bias = self.relative_bias_table.gather(
0, self.relative_index.repeat(1, self.heads))
relative_bias = rearrange(
relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
dots = dots + relative_bias
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
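The relative-position machinery in `__init__` can be hard to follow, so here is a pure-Python sketch of the same index computation for a tiny 2x2 feature map (the formula mirrors the buffer built above):

```python
# Relative-position index for a 2x2 feature map (ih = iw = 2), computed the
# same way as Attention.__init__: offset (dh, dw) maps to table row
# (dh + ih - 1) * (2*iw - 1) + (dw + iw - 1).
ih = iw = 2
coords = [(h, w) for h in range(ih) for w in range(iw)]  # 4 positions

index = [[(h1 - h2 + ih - 1) * (2 * iw - 1) + (w1 - w2 + iw - 1)
          for (h2, w2) in coords]
         for (h1, w1) in coords]

distinct = {v for row in index for v in row}
print(sorted(distinct))  # → [0, 1, 2, 3, 4, 5, 6, 7, 8]
```

All 4x4 = 16 query/key pairs share just (2*ih-1)*(2*iw-1) = 9 learnable biases, one per distinct spatial offset, which is exactly the number of rows in `relative_bias_table`.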
class Transformer(nn.Module):
def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
super().__init__()
hidden_dim = int(inp * 4)
self.ih, self.iw = image_size
self.downsample = downsample
if self.downsample:
self.pool1 = nn.MaxPool2d(3, 2, 1)
self.pool2 = nn.MaxPool2d(3, 2, 1)
self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
self.attn = Attention(inp, oup, image_size, heads, dim_head, dropout)
self.ff = FeedForward(oup, hidden_dim, dropout)
self.attn = nn.Sequential(
Rearrange('b c ih iw -> b (ih iw) c'),
PreNorm(inp, self.attn, nn.LayerNorm),
Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
)
self.ff = nn.Sequential(
Rearrange('b c ih iw -> b (ih iw) c'),
PreNorm(oup, self.ff, nn.LayerNorm),
Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
)
def forward(self, x):
if self.downsample:
x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
else:
x = x + self.attn(x)
x = x + self.ff(x)
return x
class CoAtNet(nn.Module):
def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000, block_types=['C', 'C', 'T', 'T']):
super().__init__()
ih, iw = image_size
block = {'C': MBConv, 'T': Transformer}
self.s0 = self._make_layer(
conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
self.s1 = self._make_layer(
block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
self.s2 = self._make_layer(
block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
self.s3 = self._make_layer(
block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
self.s4 = self._make_layer(
block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))
self.pool = nn.AvgPool2d(ih // 32, 1)
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
def forward(self, x):
x = self.s0(x)
x = self.s1(x)
x = self.s2(x)
x = self.s3(x)
x = self.s4(x)
x = self.pool(x).view(-1, x.shape[1])
x = self.fc(x)
return x
def _make_layer(self, block, inp, oup, depth, image_size):
layers = nn.ModuleList([])
for i in range(depth):
if i == 0:
layers.append(block(inp, oup, image_size, downsample=True))
else:
layers.append(block(oup, oup, image_size))
return nn.Sequential(*layers)
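Each of the five stages halves the resolution, which matters for the 32x32 CIFAR input used later: by the final stage the feature map is a single pixel. A sketch of the resolution schedule (arithmetic only, not model code):

```python
# CoAtNet stage resolutions for a 32x32 input: stage k sees ih // 2**(k+1),
# matching the (ih // 2, ...) through (ih // 32, ...) arguments in __init__.
ih = 32
for k, name in enumerate(['s0', 's1', 's2', 's3', 's4']):
    side = ih // 2 ** (k + 1)
    print(f'{name}: {side}x{side}')
```

By `s4` the map is 1x1, so `nn.AvgPool2d(ih // 32, 1)` pools a single spatial position and the transformer stages operate on very short token sequences at this input size.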
DenseNet is similar to ResNet in that features from earlier layers are also passed forward as the network deepens, but what sets DenseNet apart from ResNet is the connectivity pattern of those residuals. What does this mean?

Basically, the idea behind DenseNet is that the input to each layer is the concatenation of the inputs and feature maps of all the previous layers.
This connectivity pattern has shown improvements over the ResNet architecture.
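Concretely, with the constructor defaults used below (growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64), the channel count evolves as follows; this is a pure-arithmetic mirror of the loop in `DenseNet.__init__`:

```python
# Channel bookkeeping for the DenseNet-121-style defaults: every dense layer
# concatenates growth_rate new channels; each transition halves the total.
growth_rate = 32
block_config = (6, 12, 24, 16)
num_features = 64  # after the stem convolution

for i, num_layers in enumerate(block_config):
    num_features += num_layers * growth_rate   # dense block
    if i != len(block_config) - 1:
        num_features //= 2                      # transition layer
    print(f'after stage {i + 1}: {num_features} channels')

print(num_features)  # → 1024 channels feed the final classifier
```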
Code referenced from https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py
Edits made to original source code by:
class _DenseLayer(nn.Module):
def __init__(
self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
) -> None:
super().__init__()
self.norm1 = nn.BatchNorm2d(num_input_features)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
def bn_function(self, inputs: List[torch.Tensor]) -> torch.Tensor:
concated_features = torch.cat(inputs, 1)
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
return bottleneck_output
# todo: rewrite when torchscript supports any
def any_requires_grad(self, input: List[torch.Tensor]) -> bool:
for tensor in input:
if tensor.requires_grad:
return True
return False
@torch.jit.unused # noqa: T484
def call_checkpoint_bottleneck(self, input: List[torch.Tensor]) -> torch.Tensor:
def closure(*inputs):
return self.bn_function(inputs)
return cp.checkpoint(closure, *input)
@torch.jit._overload_method # noqa: F811
def forward(self, input: List[torch.Tensor]) -> torch.Tensor: # noqa: F811
pass
@torch.jit._overload_method # noqa: F811
def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: F811
pass
# torchscript does not yet support *args, so we overload method
# allowing it to take either a List[Tensor] or single Tensor
def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: F811
if isinstance(input, torch.Tensor):
prev_features = [input]
else:
prev_features = input
if self.memory_efficient and self.any_requires_grad(prev_features):
if torch.jit.is_scripting():
raise Exception("Memory Efficient not supported in JIT")
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
else:
bottleneck_output = self.bn_function(prev_features)
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return new_features
class _DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self,
num_layers: int,
num_input_features: int,
bn_size: int,
growth_rate: int,
drop_rate: float,
memory_efficient: bool = False,
) -> None:
super().__init__()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
self.add_module("denselayer%d" % (i + 1), layer)
def forward(self, init_features: torch.Tensor) -> torch.Tensor:
features = [init_features]
for name, layer in self.items():
new_features = layer(features)
features.append(new_features)
return torch.cat(features, 1)
class _Transition(nn.Sequential):
def __init__(self, num_input_features: int, num_output_features: int) -> None:
super().__init__()
self.norm = nn.BatchNorm2d(num_input_features)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottle neck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
"""
def __init__(
self,
growth_rate: int = 32,
block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
num_init_features: int = 64,
bn_size: int = 4,
drop_rate: float = 0,
num_classes: int = 100,
memory_efficient: bool = False,
) -> None:
super().__init__()
# First convolution
self.features = nn.Sequential(
OrderedDict(
[
("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
("norm0", nn.BatchNorm2d(num_init_features)),
("relu0", nn.ReLU(inplace=True)),
("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]
)
)
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
self.features.add_module("denseblock%d" % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
self.features.add_module("transition%d" % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module("norm5", nn.BatchNorm2d(num_features))
# Linear layer
self.classifier = nn.Linear(num_features, num_classes)
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
out = self.classifier(out)
return out
# init early stopper with patience of 18
earlyStopper = EarlyStopping(18)
Initializing models for the fine-labeled dataset
# Baseline VGG inspired model
print('\nSimpleNet')
Baseline = SimpleNet([2,3,4,4],3,32,2,2048,100)
summary(Baseline,(3,32,32))
# ResNet architectures
ResNet34 = ResNet(Bottleneck,[3,4,6,3]) # note: [3,4,6,3] with Bottleneck blocks is the ResNet-50 layout; BasicBlock would give ResNet-34
print('\nResNet34 - Medium')
summary(ResNet34,(3,32,32))
ResNeXtCustom1 = ResNet(Bottleneck,[3,4,6,3] ,groups=32,width_per_group=4) #
print('\nResNeXt Custom 1 -> ResNet34 + cardinality(32), 4 "width" per cardinality - High Complexity')
summary(ResNeXtCustom1,(3,32,32))
# Attention architectures
CoAtNetCustom1 = CoAtNet((32, 32), 3, [2, 2, 6, 14, 2], [64, 96, 192, 384, 768], num_classes=100)
print('\nCoAtNet Simple')
summary(CoAtNetCustom1,(3,32,32))
# DenseNet architectures
DenseNetCustom1 = DenseNet(32, (6, 12, 48, 32), 64, num_classes=100)
print('\nDenseNet - Complex')
summary(DenseNetCustom1,(3,32,32))
SimpleNet (torch-summary output, condensed): Total params: 9,198,276 | Trainable params: 9,198,276 | Non-trainable params: 0 | Total mult-adds (M): 90.75 | Input size (MB): 0.01 | Forward/backward pass size (MB): 2.12 | Params size (MB): 35.09 | Estimated Total Size (MB): 37.23
ResNet34 - Medium (torch-summary output): per-layer Bottleneck listing truncated in this export; totals not recoverable.
256, 2, 2] 512 | | └─ReLU: 3-99 [-1, 256, 2, 2] -- | | └─Conv2d: 3-100 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-101 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-102 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-12 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-103 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-104 [-1, 256, 2, 2] 512 | | └─ReLU: 3-105 [-1, 256, 2, 2] -- | | └─Conv2d: 3-106 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-107 [-1, 256, 2, 2] 512 | | └─ReLU: 3-108 [-1, 256, 2, 2] -- | | └─Conv2d: 3-109 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-110 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-111 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-13 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-112 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-113 [-1, 256, 2, 2] 512 | | └─ReLU: 3-114 [-1, 256, 2, 2] -- | | └─Conv2d: 3-115 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-116 [-1, 256, 2, 2] 512 | | └─ReLU: 3-117 [-1, 256, 2, 2] -- | | └─Conv2d: 3-118 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-119 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-120 [-1, 1024, 2, 2] -- ├─Sequential: 1-8 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-14 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-121 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-122 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-123 [-1, 512, 2, 2] -- | | └─Conv2d: 3-124 [-1, 512, 1, 1] 2,359,296 | | └─BatchNorm2d: 3-125 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-126 [-1, 512, 1, 1] -- | | └─Conv2d: 3-127 [-1, 2048, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-128 [-1, 2048, 1, 1] 4,096 | | └─Sequential: 3-129 [-1, 2048, 1, 1] 2,101,248 | | └─ReLU: 3-130 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-15 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-131 [-1, 512, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-132 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-133 [-1, 512, 1, 1] -- | | └─Conv2d: 3-134 [-1, 512, 1, 1] 2,359,296 | | └─BatchNorm2d: 3-135 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-136 [-1, 512, 1, 1] -- | | └─Conv2d: 3-137 [-1, 2048, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-138 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-139 [-1, 2048, 1, 1] -- | 
└─Bottleneck: 2-16 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-140 [-1, 512, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-141 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-142 [-1, 512, 1, 1] -- | | └─Conv2d: 3-143 [-1, 512, 1, 1] 2,359,296 | | └─BatchNorm2d: 3-144 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-145 [-1, 512, 1, 1] -- | | └─Conv2d: 3-146 [-1, 2048, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-147 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-148 [-1, 2048, 1, 1] -- ├─AdaptiveAvgPool2d: 1-9 [-1, 2048, 1, 1] -- ├─Linear: 1-10 [-1, 100] 204,900 ========================================================================================== Total params: 23,712,932 Trainable params: 23,712,932 Non-trainable params: 0 Total mult-adds (M): 133.36 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 3.46 Params size (MB): 90.46 Estimated Total Size (MB): 93.93 ========================================================================================== ResNeXt Custom 2 -> ResNet34 + cardinality(32), 4 "width" per cardinality - High Complexity ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Conv2d: 1-1 [-1, 64, 16, 16] 9,408 ├─BatchNorm2d: 1-2 [-1, 64, 16, 16] 128 ├─ReLU: 1-3 [-1, 64, 16, 16] -- ├─MaxPool2d: 1-4 [-1, 64, 8, 8] -- ├─Sequential: 1-5 [-1, 256, 8, 8] -- | └─Bottleneck: 2-1 [-1, 256, 8, 8] -- | | └─Conv2d: 3-1 [-1, 128, 8, 8] 8,192 | | └─BatchNorm2d: 3-2 [-1, 128, 8, 8] 256 | | └─ReLU: 3-3 [-1, 128, 8, 8] -- | | └─Conv2d: 3-4 [-1, 128, 8, 8] 4,608 | | └─BatchNorm2d: 3-5 [-1, 128, 8, 8] 256 | | └─ReLU: 3-6 [-1, 128, 8, 8] -- | | └─Conv2d: 3-7 [-1, 256, 8, 8] 32,768 | | └─BatchNorm2d: 3-8 [-1, 256, 8, 8] 512 | | └─Sequential: 3-9 [-1, 256, 8, 8] 16,896 | | └─ReLU: 3-10 [-1, 256, 8, 8] -- | └─Bottleneck: 2-2 [-1, 256, 8, 8] -- | | └─Conv2d: 
3-11 [-1, 128, 8, 8] 32,768 | | └─BatchNorm2d: 3-12 [-1, 128, 8, 8] 256 | | └─ReLU: 3-13 [-1, 128, 8, 8] -- | | └─Conv2d: 3-14 [-1, 128, 8, 8] 4,608 | | └─BatchNorm2d: 3-15 [-1, 128, 8, 8] 256 | | └─ReLU: 3-16 [-1, 128, 8, 8] -- | | └─Conv2d: 3-17 [-1, 256, 8, 8] 32,768 | | └─BatchNorm2d: 3-18 [-1, 256, 8, 8] 512 | | └─ReLU: 3-19 [-1, 256, 8, 8] -- | └─Bottleneck: 2-3 [-1, 256, 8, 8] -- | | └─Conv2d: 3-20 [-1, 128, 8, 8] 32,768 | | └─BatchNorm2d: 3-21 [-1, 128, 8, 8] 256 | | └─ReLU: 3-22 [-1, 128, 8, 8] -- | | └─Conv2d: 3-23 [-1, 128, 8, 8] 4,608 | | └─BatchNorm2d: 3-24 [-1, 128, 8, 8] 256 | | └─ReLU: 3-25 [-1, 128, 8, 8] -- | | └─Conv2d: 3-26 [-1, 256, 8, 8] 32,768 | | └─BatchNorm2d: 3-27 [-1, 256, 8, 8] 512 | | └─ReLU: 3-28 [-1, 256, 8, 8] -- ├─Sequential: 1-6 [-1, 512, 4, 4] -- | └─Bottleneck: 2-4 [-1, 512, 4, 4] -- | | └─Conv2d: 3-29 [-1, 256, 8, 8] 65,536 | | └─BatchNorm2d: 3-30 [-1, 256, 8, 8] 512 | | └─ReLU: 3-31 [-1, 256, 8, 8] -- | | └─Conv2d: 3-32 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-33 [-1, 256, 4, 4] 512 | | └─ReLU: 3-34 [-1, 256, 4, 4] -- | | └─Conv2d: 3-35 [-1, 512, 4, 4] 131,072 | | └─BatchNorm2d: 3-36 [-1, 512, 4, 4] 1,024 | | └─Sequential: 3-37 [-1, 512, 4, 4] 132,096 | | └─ReLU: 3-38 [-1, 512, 4, 4] -- | └─Bottleneck: 2-5 [-1, 512, 4, 4] -- | | └─Conv2d: 3-39 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-40 [-1, 256, 4, 4] 512 | | └─ReLU: 3-41 [-1, 256, 4, 4] -- | | └─Conv2d: 3-42 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-43 [-1, 256, 4, 4] 512 | | └─ReLU: 3-44 [-1, 256, 4, 4] -- | | └─Conv2d: 3-45 [-1, 512, 4, 4] 131,072 | | └─BatchNorm2d: 3-46 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-47 [-1, 512, 4, 4] -- | └─Bottleneck: 2-6 [-1, 512, 4, 4] -- | | └─Conv2d: 3-48 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-49 [-1, 256, 4, 4] 512 | | └─ReLU: 3-50 [-1, 256, 4, 4] -- | | └─Conv2d: 3-51 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-52 [-1, 256, 4, 4] 512 | | └─ReLU: 3-53 [-1, 256, 4, 4] -- | | └─Conv2d: 3-54 [-1, 512, 4, 4] 131,072 | | 
└─BatchNorm2d: 3-55 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-56 [-1, 512, 4, 4] -- | └─Bottleneck: 2-7 [-1, 512, 4, 4] -- | | └─Conv2d: 3-57 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-58 [-1, 256, 4, 4] 512 | | └─ReLU: 3-59 [-1, 256, 4, 4] -- | | └─Conv2d: 3-60 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-61 [-1, 256, 4, 4] 512 | | └─ReLU: 3-62 [-1, 256, 4, 4] -- | | └─Conv2d: 3-63 [-1, 512, 4, 4] 131,072 | | └─BatchNorm2d: 3-64 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-65 [-1, 512, 4, 4] -- ├─Sequential: 1-7 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-8 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-66 [-1, 512, 4, 4] 262,144 | | └─BatchNorm2d: 3-67 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-68 [-1, 512, 4, 4] -- | | └─Conv2d: 3-69 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-70 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-71 [-1, 512, 2, 2] -- | | └─Conv2d: 3-72 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-73 [-1, 1024, 2, 2] 2,048 | | └─Sequential: 3-74 [-1, 1024, 2, 2] 526,336 | | └─ReLU: 3-75 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-9 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-76 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-77 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-78 [-1, 512, 2, 2] -- | | └─Conv2d: 3-79 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-80 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-81 [-1, 512, 2, 2] -- | | └─Conv2d: 3-82 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-83 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-84 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-10 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-85 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-86 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-87 [-1, 512, 2, 2] -- | | └─Conv2d: 3-88 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-89 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-90 [-1, 512, 2, 2] -- | | └─Conv2d: 3-91 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-92 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-93 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-11 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-94 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-95 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-96 [-1, 512, 2, 2] -- | | 
└─Conv2d: 3-97 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-98 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-99 [-1, 512, 2, 2] -- | | └─Conv2d: 3-100 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-101 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-102 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-12 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-103 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-104 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-105 [-1, 512, 2, 2] -- | | └─Conv2d: 3-106 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-107 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-108 [-1, 512, 2, 2] -- | | └─Conv2d: 3-109 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-110 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-111 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-13 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-112 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-113 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-114 [-1, 512, 2, 2] -- | | └─Conv2d: 3-115 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-116 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-117 [-1, 512, 2, 2] -- | | └─Conv2d: 3-118 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-119 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-120 [-1, 1024, 2, 2] -- ├─Sequential: 1-8 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-14 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-121 [-1, 1024, 2, 2] 1,048,576 | | └─BatchNorm2d: 3-122 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-123 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-124 [-1, 1024, 1, 1] 294,912 | | └─BatchNorm2d: 3-125 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-126 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-127 [-1, 2048, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-128 [-1, 2048, 1, 1] 4,096 | | └─Sequential: 3-129 [-1, 2048, 1, 1] 2,101,248 | | └─ReLU: 3-130 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-15 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-131 [-1, 1024, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-132 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-133 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-134 [-1, 1024, 1, 1] 294,912 | | └─BatchNorm2d: 3-135 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-136 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-137 [-1, 2048, 1, 1] 2,097,152 | | 
└─BatchNorm2d: 3-138 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-139 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-16 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-140 [-1, 1024, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-141 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-142 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-143 [-1, 1024, 1, 1] 294,912 | | └─BatchNorm2d: 3-144 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-145 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-146 [-1, 2048, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-147 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-148 [-1, 2048, 1, 1] -- ├─AdaptiveAvgPool2d: 1-9 [-1, 2048, 1, 1] -- ├─Linear: 1-10 [-1, 100] 204,900 ========================================================================================== Total params: 23,184,804 Trainable params: 23,184,804 Non-trainable params: 0 Total mult-adds (M): 135.18 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 4.49 Params size (MB): 88.44 Estimated Total Size (MB): 92.94 ==========================================================================================
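The ResNeXt variant matches the ResNet's parameter budget despite doubling the bottleneck width because its 3x3 convolutions are grouped: with cardinality g, each output channel only sees in_channels/g inputs, shrinking the weight count by a factor of g. A standalone sanity check of that formula (the helper name `conv2d_params` is my own):

```python
def conv2d_params(in_ch, out_ch, k, groups=1, bias=False):
    """Parameter count of a Conv2d: under grouped convolution each
    output channel convolves only in_ch // groups input channels."""
    p = out_ch * (in_ch // groups) * k * k
    return p + out_ch if bias else p

# SimpleNet stem: 3 -> 32 channels, 5x5 kernel, with bias
print(conv2d_params(3, 32, 5, bias=True))     # 2432
# ResNet bottleneck 3x3: 64 -> 64, dense
print(conv2d_params(64, 64, 3))               # 36864
# ResNeXt bottleneck 3x3: 128 -> 128, 32 groups of width 4
print(conv2d_params(128, 128, 3, groups=32))  # 4608
```

These values reproduce the corresponding Param # entries reported by torch-summary for the three models.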
CoAtNet Simple
Total params: 32,396,096 (all trainable) | Mult-adds: 139.38 M
Input size: 0.01 MB | Forward/backward pass: 0.92 MB | Params: 123.58 MB | Estimated total: 124.52 MB

DenseNet - Complex
Total params: 18,285,028 (all trainable) | Mult-adds: 139.99 M
Input size: 0.01 MB | Forward/backward pass: 5.07 MB | Params: 69.75 MB | Estimated total: 74.83 MB
Initializing models for the coarse-labeled dataset
# Baseline VGG inspired model
print('\nSimpleNet Coarse')
BaselineCoarse = SimpleNet([2,3,4,4],3,32,2,2048,20)
summary(BaselineCoarse,(3,32,32))
# ResNet architectures
ResNet34Coarse = ResNet(Bottleneck,[3,4,6,3], num_classes=20)  # Bottleneck blocks in a [3,4,6,3] layout
print('\nResNet34 Coarse - Medium')
summary(ResNet34Coarse,(3,32,32))
ResNeXtCustom1Coarse = ResNet(Bottleneck, [3,4,6,3], groups=32, width_per_group=4, num_classes=20)  # ResNeXt-style grouped bottlenecks
print('\nResNeXt Custom 2 Coarse -> ResNet34 + cardinality(32), 4 "width" per cardinality - High Complexity')
summary(ResNeXtCustom1Coarse,(3,32,32))
# Attention architectures
CoAtNetCustom1Coarse = CoAtNet((32, 32), 3, [2, 2, 6, 14, 2], [64, 96, 192, 384, 768], num_classes=20)
print('\nCoAtNet Simple Coarse')
summary(CoAtNetCustom1Coarse,(3,32,32))
# DenseNet architectures
DenseNetCustom1Coarse = DenseNet(32, (6, 12, 48, 32), 64, num_classes=20)
print('\nDenseNet Coarse - Complex')
summary(DenseNetCustom1Coarse,(3,32,32))
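The 20-way heads above assume coarse (superclass) targets, but `torchvision.datasets.CIFAR100` returns fine labels (0–99). One way to bridge this is a `target_transform` backed by a fine-to-coarse lookup table. The sketch below uses a placeholder table (simple blocks of five) purely to show the mechanism — it is NOT the real CIFAR-100 grouping, which must be taken from the dataset's `coarse_labels` field.

```python
# Sketch of a fine->coarse target_transform for the 20-superclass models.
# PLACEHOLDER lookup: maps fine labels 0-4 -> 0, 5-9 -> 1, ... The actual
# CIFAR-100 fine->superclass assignment differs and should be read from the
# dataset's coarse_labels metadata.
FINE_TO_COARSE = [i // 5 for i in range(100)]

def to_coarse(fine_label: int) -> int:
    """Map a fine label (0..99) to a coarse label (0..19) via the lookup."""
    return FINE_TO_COARSE[fine_label]

# Usage (assumes the data is already downloaded, as elsewhere in this notebook):
# coarse_train = datasets.CIFAR100('CIFAR100', train=True,
#                                  transform=transforms.ToTensor(),
#                                  target_transform=to_coarse)
```

With `target_transform` in place, the existing DataLoader pipeline needs no other changes to train the 20-class models.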
SimpleNet Coarse ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 32, 30, 30] -- | └─Conv2d: 2-1 [-1, 32, 30, 30] 2,432 | └─BatchNorm2d: 2-2 [-1, 32, 30, 30] 64 | └─ReLU: 2-3 [-1, 32, 30, 30] -- ├─Sequential: 1-2 [-1, 32, 15, 15] -- | └─Conv2d: 2-4 [-1, 32, 30, 30] 9,248 | └─BatchNorm2d: 2-5 [-1, 32, 30, 30] 64 | └─ReLU: 2-6 [-1, 32, 30, 30] -- | └─MaxPool2d: 2-7 [-1, 32, 15, 15] -- ├─Sequential: 1-3 [-1, 64, 15, 15] -- | └─Conv2d: 2-8 [-1, 64, 15, 15] 18,496 | └─BatchNorm2d: 2-9 [-1, 64, 15, 15] 128 | └─ReLU: 2-10 [-1, 64, 15, 15] -- ├─Sequential: 1-4 [-1, 64, 7, 7] -- | └─Conv2d: 2-11 [-1, 64, 15, 15] 36,928 | └─BatchNorm2d: 2-12 [-1, 64, 15, 15] 128 | └─ReLU: 2-13 [-1, 64, 15, 15] -- | └─Conv2d: 2-14 [-1, 64, 15, 15] 36,928 | └─BatchNorm2d: 2-15 [-1, 64, 15, 15] 128 | └─ReLU: 2-16 [-1, 64, 15, 15] -- | └─MaxPool2d: 2-17 [-1, 64, 7, 7] -- ├─Sequential: 1-5 [-1, 128, 7, 7] -- | └─Conv2d: 2-18 [-1, 128, 7, 7] 73,856 | └─BatchNorm2d: 2-19 [-1, 128, 7, 7] 256 | └─ReLU: 2-20 [-1, 128, 7, 7] -- ├─Sequential: 1-6 [-1, 128, 3, 3] -- | └─Conv2d: 2-21 [-1, 128, 7, 7] 147,584 | └─BatchNorm2d: 2-22 [-1, 128, 7, 7] 256 | └─ReLU: 2-23 [-1, 128, 7, 7] -- | └─Conv2d: 2-24 [-1, 128, 7, 7] 147,584 | └─BatchNorm2d: 2-25 [-1, 128, 7, 7] 256 | └─ReLU: 2-26 [-1, 128, 7, 7] -- | └─Conv2d: 2-27 [-1, 128, 7, 7] 147,584 | └─BatchNorm2d: 2-28 [-1, 128, 7, 7] 256 | └─ReLU: 2-29 [-1, 128, 7, 7] -- | └─MaxPool2d: 2-30 [-1, 128, 3, 3] -- ├─Sequential: 1-7 [-1, 256, 3, 3] -- | └─Conv2d: 2-31 [-1, 256, 3, 3] 295,168 | └─BatchNorm2d: 2-32 [-1, 256, 3, 3] 512 | └─ReLU: 2-33 [-1, 256, 3, 3] -- ├─Sequential: 1-8 [-1, 256, 2, 2] -- | └─Conv2d: 2-34 [-1, 256, 3, 3] 590,080 | └─BatchNorm2d: 2-35 [-1, 256, 3, 3] 512 | └─ReLU: 2-36 [-1, 256, 3, 3] -- | └─Conv2d: 2-37 [-1, 256, 3, 3] 590,080 | └─BatchNorm2d: 
2-38 [-1, 256, 3, 3] 512 | └─ReLU: 2-39 [-1, 256, 3, 3] -- | └─Conv2d: 2-40 [-1, 256, 3, 3] 590,080 | └─BatchNorm2d: 2-41 [-1, 256, 3, 3] 512 | └─ReLU: 2-42 [-1, 256, 3, 3] -- | └─MaxPool2d: 2-43 [-1, 256, 2, 2] -- ├─Flatten: 1-9 [-1, 1024] -- ├─Sequential: 1-10 [-1, 100] -- | └─Linear: 2-44 [-1, 2048] 2,099,200 | └─BatchNorm1d: 2-45 [-1, 2048] 4,096 | └─ReLU: 2-46 [-1, 2048] -- | └─Dropout: 2-47 [-1, 2048] -- | └─Linear: 2-48 [-1, 2048] 4,196,352 | └─BatchNorm1d: 2-49 [-1, 2048] 4,096 | └─ReLU: 2-50 [-1, 2048] -- | └─Linear: 2-51 [-1, 100] 204,900 ========================================================================================== Total params: 9,198,276 Trainable params: 9,198,276 Non-trainable params: 0 Total mult-adds (M): 90.75 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 2.12 Params size (MB): 35.09 Estimated Total Size (MB): 37.23 ========================================================================================== ResNet34 Coarse - Medium ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Conv2d: 1-1 [-1, 64, 16, 16] 9,408 ├─BatchNorm2d: 1-2 [-1, 64, 16, 16] 128 ├─ReLU: 1-3 [-1, 64, 16, 16] -- ├─MaxPool2d: 1-4 [-1, 64, 8, 8] -- ├─Sequential: 1-5 [-1, 256, 8, 8] -- | └─Bottleneck: 2-1 [-1, 256, 8, 8] -- | | └─Conv2d: 3-1 [-1, 64, 8, 8] 4,096 | | └─BatchNorm2d: 3-2 [-1, 64, 8, 8] 128 | | └─ReLU: 3-3 [-1, 64, 8, 8] -- | | └─Conv2d: 3-4 [-1, 64, 8, 8] 36,864 | | └─BatchNorm2d: 3-5 [-1, 64, 8, 8] 128 | | └─ReLU: 3-6 [-1, 64, 8, 8] -- | | └─Conv2d: 3-7 [-1, 256, 8, 8] 16,384 | | └─BatchNorm2d: 3-8 [-1, 256, 8, 8] 512 | | └─Sequential: 3-9 [-1, 256, 8, 8] 16,896 | | └─ReLU: 3-10 [-1, 256, 8, 8] -- | └─Bottleneck: 2-2 [-1, 256, 8, 8] -- | | └─Conv2d: 3-11 [-1, 64, 8, 8] 16,384 | | 
└─BatchNorm2d: 3-12 [-1, 64, 8, 8] 128 | | └─ReLU: 3-13 [-1, 64, 8, 8] -- | | └─Conv2d: 3-14 [-1, 64, 8, 8] 36,864 | | └─BatchNorm2d: 3-15 [-1, 64, 8, 8] 128 | | └─ReLU: 3-16 [-1, 64, 8, 8] -- | | └─Conv2d: 3-17 [-1, 256, 8, 8] 16,384 | | └─BatchNorm2d: 3-18 [-1, 256, 8, 8] 512 | | └─ReLU: 3-19 [-1, 256, 8, 8] -- | └─Bottleneck: 2-3 [-1, 256, 8, 8] -- | | └─Conv2d: 3-20 [-1, 64, 8, 8] 16,384 | | └─BatchNorm2d: 3-21 [-1, 64, 8, 8] 128 | | └─ReLU: 3-22 [-1, 64, 8, 8] -- | | └─Conv2d: 3-23 [-1, 64, 8, 8] 36,864 | | └─BatchNorm2d: 3-24 [-1, 64, 8, 8] 128 | | └─ReLU: 3-25 [-1, 64, 8, 8] -- | | └─Conv2d: 3-26 [-1, 256, 8, 8] 16,384 | | └─BatchNorm2d: 3-27 [-1, 256, 8, 8] 512 | | └─ReLU: 3-28 [-1, 256, 8, 8] -- ├─Sequential: 1-6 [-1, 512, 4, 4] -- | └─Bottleneck: 2-4 [-1, 512, 4, 4] -- | | └─Conv2d: 3-29 [-1, 128, 8, 8] 32,768 | | └─BatchNorm2d: 3-30 [-1, 128, 8, 8] 256 | | └─ReLU: 3-31 [-1, 128, 8, 8] -- | | └─Conv2d: 3-32 [-1, 128, 4, 4] 147,456 | | └─BatchNorm2d: 3-33 [-1, 128, 4, 4] 256 | | └─ReLU: 3-34 [-1, 128, 4, 4] -- | | └─Conv2d: 3-35 [-1, 512, 4, 4] 65,536 | | └─BatchNorm2d: 3-36 [-1, 512, 4, 4] 1,024 | | └─Sequential: 3-37 [-1, 512, 4, 4] 132,096 | | └─ReLU: 3-38 [-1, 512, 4, 4] -- | └─Bottleneck: 2-5 [-1, 512, 4, 4] -- | | └─Conv2d: 3-39 [-1, 128, 4, 4] 65,536 | | └─BatchNorm2d: 3-40 [-1, 128, 4, 4] 256 | | └─ReLU: 3-41 [-1, 128, 4, 4] -- | | └─Conv2d: 3-42 [-1, 128, 4, 4] 147,456 | | └─BatchNorm2d: 3-43 [-1, 128, 4, 4] 256 | | └─ReLU: 3-44 [-1, 128, 4, 4] -- | | └─Conv2d: 3-45 [-1, 512, 4, 4] 65,536 | | └─BatchNorm2d: 3-46 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-47 [-1, 512, 4, 4] -- | └─Bottleneck: 2-6 [-1, 512, 4, 4] -- | | └─Conv2d: 3-48 [-1, 128, 4, 4] 65,536 | | └─BatchNorm2d: 3-49 [-1, 128, 4, 4] 256 | | └─ReLU: 3-50 [-1, 128, 4, 4] -- | | └─Conv2d: 3-51 [-1, 128, 4, 4] 147,456 | | └─BatchNorm2d: 3-52 [-1, 128, 4, 4] 256 | | └─ReLU: 3-53 [-1, 128, 4, 4] -- | | └─Conv2d: 3-54 [-1, 512, 4, 4] 65,536 | | └─BatchNorm2d: 3-55 [-1, 512, 4, 4] 1,024 | | └─ReLU: 
3-56 [-1, 512, 4, 4] -- | └─Bottleneck: 2-7 [-1, 512, 4, 4] -- | | └─Conv2d: 3-57 [-1, 128, 4, 4] 65,536 | | └─BatchNorm2d: 3-58 [-1, 128, 4, 4] 256 | | └─ReLU: 3-59 [-1, 128, 4, 4] -- | | └─Conv2d: 3-60 [-1, 128, 4, 4] 147,456 | | └─BatchNorm2d: 3-61 [-1, 128, 4, 4] 256 | | └─ReLU: 3-62 [-1, 128, 4, 4] -- | | └─Conv2d: 3-63 [-1, 512, 4, 4] 65,536 | | └─BatchNorm2d: 3-64 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-65 [-1, 512, 4, 4] -- ├─Sequential: 1-7 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-8 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-66 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-67 [-1, 256, 4, 4] 512 | | └─ReLU: 3-68 [-1, 256, 4, 4] -- | | └─Conv2d: 3-69 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-70 [-1, 256, 2, 2] 512 | | └─ReLU: 3-71 [-1, 256, 2, 2] -- | | └─Conv2d: 3-72 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-73 [-1, 1024, 2, 2] 2,048 | | └─Sequential: 3-74 [-1, 1024, 2, 2] 526,336 | | └─ReLU: 3-75 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-9 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-76 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-77 [-1, 256, 2, 2] 512 | | └─ReLU: 3-78 [-1, 256, 2, 2] -- | | └─Conv2d: 3-79 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-80 [-1, 256, 2, 2] 512 | | └─ReLU: 3-81 [-1, 256, 2, 2] -- | | └─Conv2d: 3-82 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-83 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-84 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-10 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-85 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-86 [-1, 256, 2, 2] 512 | | └─ReLU: 3-87 [-1, 256, 2, 2] -- | | └─Conv2d: 3-88 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-89 [-1, 256, 2, 2] 512 | | └─ReLU: 3-90 [-1, 256, 2, 2] -- | | └─Conv2d: 3-91 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-92 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-93 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-11 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-94 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-95 [-1, 256, 2, 2] 512 | | └─ReLU: 3-96 [-1, 256, 2, 2] -- | | └─Conv2d: 3-97 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-98 [-1, 
256, 2, 2] 512 | | └─ReLU: 3-99 [-1, 256, 2, 2] -- | | └─Conv2d: 3-100 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-101 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-102 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-12 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-103 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-104 [-1, 256, 2, 2] 512 | | └─ReLU: 3-105 [-1, 256, 2, 2] -- | | └─Conv2d: 3-106 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-107 [-1, 256, 2, 2] 512 | | └─ReLU: 3-108 [-1, 256, 2, 2] -- | | └─Conv2d: 3-109 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-110 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-111 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-13 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-112 [-1, 256, 2, 2] 262,144 | | └─BatchNorm2d: 3-113 [-1, 256, 2, 2] 512 | | └─ReLU: 3-114 [-1, 256, 2, 2] -- | | └─Conv2d: 3-115 [-1, 256, 2, 2] 589,824 | | └─BatchNorm2d: 3-116 [-1, 256, 2, 2] 512 | | └─ReLU: 3-117 [-1, 256, 2, 2] -- | | └─Conv2d: 3-118 [-1, 1024, 2, 2] 262,144 | | └─BatchNorm2d: 3-119 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-120 [-1, 1024, 2, 2] -- ├─Sequential: 1-8 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-14 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-121 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-122 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-123 [-1, 512, 2, 2] -- | | └─Conv2d: 3-124 [-1, 512, 1, 1] 2,359,296 | | └─BatchNorm2d: 3-125 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-126 [-1, 512, 1, 1] -- | | └─Conv2d: 3-127 [-1, 2048, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-128 [-1, 2048, 1, 1] 4,096 | | └─Sequential: 3-129 [-1, 2048, 1, 1] 2,101,248 | | └─ReLU: 3-130 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-15 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-131 [-1, 512, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-132 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-133 [-1, 512, 1, 1] -- | | └─Conv2d: 3-134 [-1, 512, 1, 1] 2,359,296 | | └─BatchNorm2d: 3-135 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-136 [-1, 512, 1, 1] -- | | └─Conv2d: 3-137 [-1, 2048, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-138 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-139 [-1, 2048, 1, 1] -- | 
└─Bottleneck: 2-16 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-140 [-1, 512, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-141 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-142 [-1, 512, 1, 1] -- | | └─Conv2d: 3-143 [-1, 512, 1, 1] 2,359,296 | | └─BatchNorm2d: 3-144 [-1, 512, 1, 1] 1,024 | | └─ReLU: 3-145 [-1, 512, 1, 1] -- | | └─Conv2d: 3-146 [-1, 2048, 1, 1] 1,048,576 | | └─BatchNorm2d: 3-147 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-148 [-1, 2048, 1, 1] -- ├─AdaptiveAvgPool2d: 1-9 [-1, 2048, 1, 1] -- ├─Linear: 1-10 [-1, 100] 204,900 ========================================================================================== Total params: 23,712,932 Trainable params: 23,712,932 Non-trainable params: 0 Total mult-adds (M): 133.36 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 3.46 Params size (MB): 90.46 Estimated Total Size (MB): 93.93 ========================================================================================== ResNeXt Custom 2 Coarse -> ResNet34 + cardinality(32), 4 "width" per cardinality - High Complexity ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Conv2d: 1-1 [-1, 64, 16, 16] 9,408 ├─BatchNorm2d: 1-2 [-1, 64, 16, 16] 128 ├─ReLU: 1-3 [-1, 64, 16, 16] -- ├─MaxPool2d: 1-4 [-1, 64, 8, 8] -- ├─Sequential: 1-5 [-1, 256, 8, 8] -- | └─Bottleneck: 2-1 [-1, 256, 8, 8] -- | | └─Conv2d: 3-1 [-1, 128, 8, 8] 8,192 | | └─BatchNorm2d: 3-2 [-1, 128, 8, 8] 256 | | └─ReLU: 3-3 [-1, 128, 8, 8] -- | | └─Conv2d: 3-4 [-1, 128, 8, 8] 4,608 | | └─BatchNorm2d: 3-5 [-1, 128, 8, 8] 256 | | └─ReLU: 3-6 [-1, 128, 8, 8] -- | | └─Conv2d: 3-7 [-1, 256, 8, 8] 32,768 | | └─BatchNorm2d: 3-8 [-1, 256, 8, 8] 512 | | └─Sequential: 3-9 [-1, 256, 8, 8] 16,896 | | └─ReLU: 3-10 [-1, 256, 8, 8] -- | └─Bottleneck: 2-2 [-1, 256, 8, 8] -- | | 
└─Conv2d: 3-11 [-1, 128, 8, 8] 32,768 | | └─BatchNorm2d: 3-12 [-1, 128, 8, 8] 256 | | └─ReLU: 3-13 [-1, 128, 8, 8] -- | | └─Conv2d: 3-14 [-1, 128, 8, 8] 4,608 | | └─BatchNorm2d: 3-15 [-1, 128, 8, 8] 256 | | └─ReLU: 3-16 [-1, 128, 8, 8] -- | | └─Conv2d: 3-17 [-1, 256, 8, 8] 32,768 | | └─BatchNorm2d: 3-18 [-1, 256, 8, 8] 512 | | └─ReLU: 3-19 [-1, 256, 8, 8] -- | └─Bottleneck: 2-3 [-1, 256, 8, 8] -- | | └─Conv2d: 3-20 [-1, 128, 8, 8] 32,768 | | └─BatchNorm2d: 3-21 [-1, 128, 8, 8] 256 | | └─ReLU: 3-22 [-1, 128, 8, 8] -- | | └─Conv2d: 3-23 [-1, 128, 8, 8] 4,608 | | └─BatchNorm2d: 3-24 [-1, 128, 8, 8] 256 | | └─ReLU: 3-25 [-1, 128, 8, 8] -- | | └─Conv2d: 3-26 [-1, 256, 8, 8] 32,768 | | └─BatchNorm2d: 3-27 [-1, 256, 8, 8] 512 | | └─ReLU: 3-28 [-1, 256, 8, 8] -- ├─Sequential: 1-6 [-1, 512, 4, 4] -- | └─Bottleneck: 2-4 [-1, 512, 4, 4] -- | | └─Conv2d: 3-29 [-1, 256, 8, 8] 65,536 | | └─BatchNorm2d: 3-30 [-1, 256, 8, 8] 512 | | └─ReLU: 3-31 [-1, 256, 8, 8] -- | | └─Conv2d: 3-32 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-33 [-1, 256, 4, 4] 512 | | └─ReLU: 3-34 [-1, 256, 4, 4] -- | | └─Conv2d: 3-35 [-1, 512, 4, 4] 131,072 | | └─BatchNorm2d: 3-36 [-1, 512, 4, 4] 1,024 | | └─Sequential: 3-37 [-1, 512, 4, 4] 132,096 | | └─ReLU: 3-38 [-1, 512, 4, 4] -- | └─Bottleneck: 2-5 [-1, 512, 4, 4] -- | | └─Conv2d: 3-39 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-40 [-1, 256, 4, 4] 512 | | └─ReLU: 3-41 [-1, 256, 4, 4] -- | | └─Conv2d: 3-42 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-43 [-1, 256, 4, 4] 512 | | └─ReLU: 3-44 [-1, 256, 4, 4] -- | | └─Conv2d: 3-45 [-1, 512, 4, 4] 131,072 | | └─BatchNorm2d: 3-46 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-47 [-1, 512, 4, 4] -- | └─Bottleneck: 2-6 [-1, 512, 4, 4] -- | | └─Conv2d: 3-48 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-49 [-1, 256, 4, 4] 512 | | └─ReLU: 3-50 [-1, 256, 4, 4] -- | | └─Conv2d: 3-51 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-52 [-1, 256, 4, 4] 512 | | └─ReLU: 3-53 [-1, 256, 4, 4] -- | | └─Conv2d: 3-54 [-1, 512, 4, 4] 131,072 | | 
└─BatchNorm2d: 3-55 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-56 [-1, 512, 4, 4] -- | └─Bottleneck: 2-7 [-1, 512, 4, 4] -- | | └─Conv2d: 3-57 [-1, 256, 4, 4] 131,072 | | └─BatchNorm2d: 3-58 [-1, 256, 4, 4] 512 | | └─ReLU: 3-59 [-1, 256, 4, 4] -- | | └─Conv2d: 3-60 [-1, 256, 4, 4] 18,432 | | └─BatchNorm2d: 3-61 [-1, 256, 4, 4] 512 | | └─ReLU: 3-62 [-1, 256, 4, 4] -- | | └─Conv2d: 3-63 [-1, 512, 4, 4] 131,072 | | └─BatchNorm2d: 3-64 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-65 [-1, 512, 4, 4] -- ├─Sequential: 1-7 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-8 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-66 [-1, 512, 4, 4] 262,144 | | └─BatchNorm2d: 3-67 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-68 [-1, 512, 4, 4] -- | | └─Conv2d: 3-69 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-70 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-71 [-1, 512, 2, 2] -- | | └─Conv2d: 3-72 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-73 [-1, 1024, 2, 2] 2,048 | | └─Sequential: 3-74 [-1, 1024, 2, 2] 526,336 | | └─ReLU: 3-75 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-9 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-76 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-77 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-78 [-1, 512, 2, 2] -- | | └─Conv2d: 3-79 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-80 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-81 [-1, 512, 2, 2] -- | | └─Conv2d: 3-82 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-83 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-84 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-10 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-85 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-86 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-87 [-1, 512, 2, 2] -- | | └─Conv2d: 3-88 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-89 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-90 [-1, 512, 2, 2] -- | | └─Conv2d: 3-91 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-92 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-93 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-11 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-94 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-95 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-96 [-1, 512, 2, 2] -- | | 
└─Conv2d: 3-97 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-98 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-99 [-1, 512, 2, 2] -- | | └─Conv2d: 3-100 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-101 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-102 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-12 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-103 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-104 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-105 [-1, 512, 2, 2] -- | | └─Conv2d: 3-106 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-107 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-108 [-1, 512, 2, 2] -- | | └─Conv2d: 3-109 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-110 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-111 [-1, 1024, 2, 2] -- | └─Bottleneck: 2-13 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-112 [-1, 512, 2, 2] 524,288 | | └─BatchNorm2d: 3-113 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-114 [-1, 512, 2, 2] -- | | └─Conv2d: 3-115 [-1, 512, 2, 2] 73,728 | | └─BatchNorm2d: 3-116 [-1, 512, 2, 2] 1,024 | | └─ReLU: 3-117 [-1, 512, 2, 2] -- | | └─Conv2d: 3-118 [-1, 1024, 2, 2] 524,288 | | └─BatchNorm2d: 3-119 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-120 [-1, 1024, 2, 2] -- ├─Sequential: 1-8 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-14 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-121 [-1, 1024, 2, 2] 1,048,576 | | └─BatchNorm2d: 3-122 [-1, 1024, 2, 2] 2,048 | | └─ReLU: 3-123 [-1, 1024, 2, 2] -- | | └─Conv2d: 3-124 [-1, 1024, 1, 1] 294,912 | | └─BatchNorm2d: 3-125 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-126 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-127 [-1, 2048, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-128 [-1, 2048, 1, 1] 4,096 | | └─Sequential: 3-129 [-1, 2048, 1, 1] 2,101,248 | | └─ReLU: 3-130 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-15 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-131 [-1, 1024, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-132 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-133 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-134 [-1, 1024, 1, 1] 294,912 | | └─BatchNorm2d: 3-135 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-136 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-137 [-1, 2048, 1, 1] 2,097,152 | | 
└─BatchNorm2d: 3-138 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-139 [-1, 2048, 1, 1] -- | └─Bottleneck: 2-16 [-1, 2048, 1, 1] -- | | └─Conv2d: 3-140 [-1, 1024, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-141 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-142 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-143 [-1, 1024, 1, 1] 294,912 | | └─BatchNorm2d: 3-144 [-1, 1024, 1, 1] 2,048 | | └─ReLU: 3-145 [-1, 1024, 1, 1] -- | | └─Conv2d: 3-146 [-1, 2048, 1, 1] 2,097,152 | | └─BatchNorm2d: 3-147 [-1, 2048, 1, 1] 4,096 | | └─ReLU: 3-148 [-1, 2048, 1, 1] -- ├─AdaptiveAvgPool2d: 1-9 [-1, 2048, 1, 1] -- ├─Linear: 1-10 [-1, 100] 204,900 ========================================================================================== Total params: 23,184,804 Trainable params: 23,184,804 Non-trainable params: 0 Total mult-adds (M): 135.18 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 4.49 Params size (MB): 88.44 Estimated Total Size (MB): 92.94 ========================================================================================== CoAtNet Simple Coarse ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 64, 16, 16] -- | └─Sequential: 2-1 [-1, 64, 16, 16] -- | | └─Conv2d: 3-1 [-1, 64, 16, 16] 1,728 | | └─BatchNorm2d: 3-2 [-1, 64, 16, 16] 128 | | └─GELU: 3-3 [-1, 64, 16, 16] -- | └─Sequential: 2-2 [-1, 64, 16, 16] -- | | └─Conv2d: 3-4 [-1, 64, 16, 16] 36,864 | | └─BatchNorm2d: 3-5 [-1, 64, 16, 16] 128 | | └─GELU: 3-6 [-1, 64, 16, 16] -- ├─Sequential: 1-2 [-1, 96, 8, 8] -- | └─MBConv: 2-3 [-1, 96, 8, 8] -- | | └─MaxPool2d: 3-7 [-1, 64, 8, 8] -- | | └─Conv2d: 3-8 [-1, 96, 8, 8] 6,144 | | └─PreNorm: 3-9 [-1, 96, 8, 8] 52,800 | └─MBConv: 2-4 [-1, 96, 8, 8] -- | | └─PreNorm: 3-10 [-1, 96, 8, 8] 97,536 ├─Sequential: 1-3 [-1, 192, 4, 
4] -- | └─MBConv: 2-5 [-1, 192, 4, 4] -- | | └─MaxPool2d: 3-11 [-1, 96, 4, 4] -- | | └─Conv2d: 3-12 [-1, 192, 4, 4] 18,432 | | └─PreNorm: 3-13 [-1, 192, 4, 4] 134,592 | └─MBConv: 2-6 [-1, 192, 4, 4] -- | | └─PreNorm: 3-14 [-1, 192, 4, 4] 379,392 | └─MBConv: 2-7 [-1, 192, 4, 4] -- | | └─PreNorm: 3-15 [-1, 192, 4, 4] 379,392 | └─MBConv: 2-8 [-1, 192, 4, 4] -- | | └─PreNorm: 3-16 [-1, 192, 4, 4] 379,392 | └─MBConv: 2-9 [-1, 192, 4, 4] -- | | └─PreNorm: 3-17 [-1, 192, 4, 4] 379,392 | └─MBConv: 2-10 [-1, 192, 4, 4] -- | | └─PreNorm: 3-18 [-1, 192, 4, 4] 379,392 ├─Sequential: 1-4 [-1, 384, 2, 2] -- | └─Transformer: 2-11 [-1, 384, 2, 2] -- | | └─MaxPool2d: 3-19 [-1, 192, 2, 2] -- | | └─Conv2d: 3-20 [-1, 384, 2, 2] 73,728 | | └─MaxPool2d: 3-21 [-1, 192, 2, 2] -- | | └─Sequential: 3-22 [-1, 384, 2, 2] 246,600 | | └─Sequential: 3-23 [-1, 384, 2, 2] 591,744 | └─Transformer: 2-12 [-1, 384, 2, 2] -- | | └─Sequential: 3-24 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-25 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-13 [-1, 384, 2, 2] -- | | └─Sequential: 3-26 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-27 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-14 [-1, 384, 2, 2] -- | | └─Sequential: 3-28 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-29 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-15 [-1, 384, 2, 2] -- | | └─Sequential: 3-30 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-31 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-16 [-1, 384, 2, 2] -- | | └─Sequential: 3-32 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-33 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-17 [-1, 384, 2, 2] -- | | └─Sequential: 3-34 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-35 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-18 [-1, 384, 2, 2] -- | | └─Sequential: 3-36 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-37 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-19 [-1, 384, 2, 2] -- | | └─Sequential: 3-38 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-39 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-20 [-1, 384, 2, 2] -- | | 
└─Sequential: 3-40 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-41 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-21 [-1, 384, 2, 2] -- | | └─Sequential: 3-42 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-43 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-22 [-1, 384, 2, 2] -- | | └─Sequential: 3-44 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-45 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-23 [-1, 384, 2, 2] -- | | └─Sequential: 3-46 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-47 [-1, 384, 2, 2] 1,182,336 | └─Transformer: 2-24 [-1, 384, 2, 2] -- | | └─Sequential: 3-48 [-1, 384, 2, 2] 394,440 | | └─Sequential: 3-49 [-1, 384, 2, 2] 1,182,336 ├─Sequential: 1-5 [-1, 768, 1, 1] -- | └─Transformer: 2-25 [-1, 768, 1, 1] -- | | └─MaxPool2d: 3-50 [-1, 384, 1, 1] -- | | └─Conv2d: 3-51 [-1, 768, 1, 1] 294,912 | | └─MaxPool2d: 3-52 [-1, 384, 1, 1] -- | | └─Sequential: 3-53 [-1, 768, 1, 1] 493,064 | | └─Sequential: 3-54 [-1, 768, 1, 1] 2,363,136 | └─Transformer: 2-26 [-1, 768, 1, 1] -- | | └─Sequential: 3-55 [-1, 768, 1, 1] 788,744 | | └─Sequential: 3-56 [-1, 768, 1, 1] 4,723,968 ├─AvgPool2d: 1-6 [-1, 768, 1, 1] -- ├─Linear: 1-7 [-1, 100] 76,800 ========================================================================================== Total params: 32,396,096 Trainable params: 32,396,096 Non-trainable params: 0 Total mult-adds (M): 139.38 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 0.92 Params size (MB): 123.58 Estimated Total Size (MB): 124.52 ========================================================================================== DenseNet Coarse - Complex ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 1920, 1, 1] -- | └─Conv2d: 2-1 [-1, 64, 16, 16] 9,408 | └─BatchNorm2d: 2-2 [-1, 64, 
16, 16] 128 | └─ReLU: 2-3 [-1, 64, 16, 16] -- | └─MaxPool2d: 2-4 [-1, 64, 8, 8] -- | └─_DenseBlock: 2-5 [-1, 256, 8, 8] -- | | └─_DenseLayer: 3-1 [-1, 32, 8, 8] 45,440 | | └─_DenseLayer: 3-2 [-1, 32, 8, 8] 49,600 | | └─_DenseLayer: 3-3 [-1, 32, 8, 8] 53,760 | | └─_DenseLayer: 3-4 [-1, 32, 8, 8] 57,920 | | └─_DenseLayer: 3-5 [-1, 32, 8, 8] 62,080 | | └─_DenseLayer: 3-6 [-1, 32, 8, 8] 66,240 | └─_Transition: 2-6 [-1, 128, 4, 4] -- | | └─BatchNorm2d: 3-7 [-1, 256, 8, 8] 512 | | └─ReLU: 3-8 [-1, 256, 8, 8] -- | | └─Conv2d: 3-9 [-1, 128, 8, 8] 32,768 | | └─AvgPool2d: 3-10 [-1, 128, 4, 4] -- | └─_DenseBlock: 2-7 [-1, 512, 4, 4] -- | | └─_DenseLayer: 3-11 [-1, 32, 4, 4] 53,760 | | └─_DenseLayer: 3-12 [-1, 32, 4, 4] 57,920 | | └─_DenseLayer: 3-13 [-1, 32, 4, 4] 62,080 | | └─_DenseLayer: 3-14 [-1, 32, 4, 4] 66,240 | | └─_DenseLayer: 3-15 [-1, 32, 4, 4] 70,400 | | └─_DenseLayer: 3-16 [-1, 32, 4, 4] 74,560 | | └─_DenseLayer: 3-17 [-1, 32, 4, 4] 78,720 | | └─_DenseLayer: 3-18 [-1, 32, 4, 4] 82,880 | | └─_DenseLayer: 3-19 [-1, 32, 4, 4] 87,040 | | └─_DenseLayer: 3-20 [-1, 32, 4, 4] 91,200 | | └─_DenseLayer: 3-21 [-1, 32, 4, 4] 95,360 | | └─_DenseLayer: 3-22 [-1, 32, 4, 4] 99,520 | └─_Transition: 2-8 [-1, 256, 2, 2] -- | | └─BatchNorm2d: 3-23 [-1, 512, 4, 4] 1,024 | | └─ReLU: 3-24 [-1, 512, 4, 4] -- | | └─Conv2d: 3-25 [-1, 256, 4, 4] 131,072 | | └─AvgPool2d: 3-26 [-1, 256, 2, 2] -- | └─_DenseBlock: 2-9 [-1, 1792, 2, 2] -- | | └─_DenseLayer: 3-27 [-1, 32, 2, 2] 70,400 | | └─_DenseLayer: 3-28 [-1, 32, 2, 2] 74,560 | | └─_DenseLayer: 3-29 [-1, 32, 2, 2] 78,720 | | └─_DenseLayer: 3-30 [-1, 32, 2, 2] 82,880 | | └─_DenseLayer: 3-31 [-1, 32, 2, 2] 87,040 | | └─_DenseLayer: 3-32 [-1, 32, 2, 2] 91,200 | | └─_DenseLayer: 3-33 [-1, 32, 2, 2] 95,360 | | └─_DenseLayer: 3-34 [-1, 32, 2, 2] 99,520 | | └─_DenseLayer: 3-35 [-1, 32, 2, 2] 103,680 | | └─_DenseLayer: 3-36 [-1, 32, 2, 2] 107,840 | | └─_DenseLayer: 3-37 [-1, 32, 2, 2] 112,000 | | └─_DenseLayer: 3-38 [-1, 32, 2, 2] 116,160 | | 
└─_DenseLayer: 3-39 [-1, 32, 2, 2] 120,320 | | └─_DenseLayer: 3-40 [-1, 32, 2, 2] 124,480 | | └─_DenseLayer: 3-41 [-1, 32, 2, 2] 128,640 | | └─_DenseLayer: 3-42 [-1, 32, 2, 2] 132,800 | | └─_DenseLayer: 3-43 [-1, 32, 2, 2] 136,960 | | └─_DenseLayer: 3-44 [-1, 32, 2, 2] 141,120 | | └─_DenseLayer: 3-45 [-1, 32, 2, 2] 145,280 | | └─_DenseLayer: 3-46 [-1, 32, 2, 2] 149,440 | | └─_DenseLayer: 3-47 [-1, 32, 2, 2] 153,600 | | └─_DenseLayer: 3-48 [-1, 32, 2, 2] 157,760 | | └─_DenseLayer: 3-49 [-1, 32, 2, 2] 161,920 | | └─_DenseLayer: 3-50 [-1, 32, 2, 2] 166,080 | | └─_DenseLayer: 3-51 [-1, 32, 2, 2] 170,240 | | └─_DenseLayer: 3-52 [-1, 32, 2, 2] 174,400 | | └─_DenseLayer: 3-53 [-1, 32, 2, 2] 178,560 | | └─_DenseLayer: 3-54 [-1, 32, 2, 2] 182,720 | | └─_DenseLayer: 3-55 [-1, 32, 2, 2] 186,880 | | └─_DenseLayer: 3-56 [-1, 32, 2, 2] 191,040 | | └─_DenseLayer: 3-57 [-1, 32, 2, 2] 195,200 | | └─_DenseLayer: 3-58 [-1, 32, 2, 2] 199,360 | | └─_DenseLayer: 3-59 [-1, 32, 2, 2] 203,520 | | └─_DenseLayer: 3-60 [-1, 32, 2, 2] 207,680 | | └─_DenseLayer: 3-61 [-1, 32, 2, 2] 211,840 | | └─_DenseLayer: 3-62 [-1, 32, 2, 2] 216,000 | | └─_DenseLayer: 3-63 [-1, 32, 2, 2] 220,160 | | └─_DenseLayer: 3-64 [-1, 32, 2, 2] 224,320 | | └─_DenseLayer: 3-65 [-1, 32, 2, 2] 228,480 | | └─_DenseLayer: 3-66 [-1, 32, 2, 2] 232,640 | | └─_DenseLayer: 3-67 [-1, 32, 2, 2] 236,800 | | └─_DenseLayer: 3-68 [-1, 32, 2, 2] 240,960 | | └─_DenseLayer: 3-69 [-1, 32, 2, 2] 245,120 | | └─_DenseLayer: 3-70 [-1, 32, 2, 2] 249,280 | | └─_DenseLayer: 3-71 [-1, 32, 2, 2] 253,440 | | └─_DenseLayer: 3-72 [-1, 32, 2, 2] 257,600 | | └─_DenseLayer: 3-73 [-1, 32, 2, 2] 261,760 | | └─_DenseLayer: 3-74 [-1, 32, 2, 2] 265,920 | └─_Transition: 2-10 [-1, 896, 1, 1] -- | | └─BatchNorm2d: 3-75 [-1, 1792, 2, 2] 3,584 | | └─ReLU: 3-76 [-1, 1792, 2, 2] -- | | └─Conv2d: 3-77 [-1, 896, 2, 2] 1,605,632 | | └─AvgPool2d: 3-78 [-1, 896, 1, 1] -- | └─_DenseBlock: 2-11 [-1, 1920, 1, 1] -- | | └─_DenseLayer: 3-79 [-1, 32, 1, 1] 153,600 | | 
└─_DenseLayer: 3-80 [-1, 32, 1, 1] 157,760 | | └─_DenseLayer: 3-81 [-1, 32, 1, 1] 161,920 | | └─_DenseLayer: 3-82 [-1, 32, 1, 1] 166,080 | | └─_DenseLayer: 3-83 [-1, 32, 1, 1] 170,240 | | └─_DenseLayer: 3-84 [-1, 32, 1, 1] 174,400 | | └─_DenseLayer: 3-85 [-1, 32, 1, 1] 178,560 | | └─_DenseLayer: 3-86 [-1, 32, 1, 1] 182,720 | | └─_DenseLayer: 3-87 [-1, 32, 1, 1] 186,880 | | └─_DenseLayer: 3-88 [-1, 32, 1, 1] 191,040 | | └─_DenseLayer: 3-89 [-1, 32, 1, 1] 195,200 | | └─_DenseLayer: 3-90 [-1, 32, 1, 1] 199,360 | | └─_DenseLayer: 3-91 [-1, 32, 1, 1] 203,520 | | └─_DenseLayer: 3-92 [-1, 32, 1, 1] 207,680 | | └─_DenseLayer: 3-93 [-1, 32, 1, 1] 211,840 | | └─_DenseLayer: 3-94 [-1, 32, 1, 1] 216,000 | | └─_DenseLayer: 3-95 [-1, 32, 1, 1] 220,160 | | └─_DenseLayer: 3-96 [-1, 32, 1, 1] 224,320 | | └─_DenseLayer: 3-97 [-1, 32, 1, 1] 228,480 | | └─_DenseLayer: 3-98 [-1, 32, 1, 1] 232,640 | | └─_DenseLayer: 3-99 [-1, 32, 1, 1] 236,800 | | └─_DenseLayer: 3-100 [-1, 32, 1, 1] 240,960 | | └─_DenseLayer: 3-101 [-1, 32, 1, 1] 245,120 | | └─_DenseLayer: 3-102 [-1, 32, 1, 1] 249,280 | | └─_DenseLayer: 3-103 [-1, 32, 1, 1] 253,440 | | └─_DenseLayer: 3-104 [-1, 32, 1, 1] 257,600 | | └─_DenseLayer: 3-105 [-1, 32, 1, 1] 261,760 | | └─_DenseLayer: 3-106 [-1, 32, 1, 1] 265,920 | | └─_DenseLayer: 3-107 [-1, 32, 1, 1] 270,080 | | └─_DenseLayer: 3-108 [-1, 32, 1, 1] 274,240 | | └─_DenseLayer: 3-109 [-1, 32, 1, 1] 278,400 | | └─_DenseLayer: 3-110 [-1, 32, 1, 1] 282,560 | └─BatchNorm2d: 2-12 [-1, 1920, 1, 1] 3,840 ├─Linear: 1-2 [-1, 100] 192,100 ========================================================================================== Total params: 18,285,028 Trainable params: 18,285,028 Non-trainable params: 0 Total mult-adds (M): 139.99 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 5.07 Params size (MB): 69.75 Estimated Total Size (MB): 74.83 
==========================================================================================
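The Param # column in the summaries above can be sanity-checked by hand: a biased `Conv2d` holds C_out·(C_in·k·k) + C_out parameters, a `Linear` holds in·out + out, and a `BatchNorm2d` holds 2·C learnable parameters. A quick stand-alone check against three rows of the SimpleNet table (my arithmetic, assuming a 5×5 first-layer kernel, which is what 2,432 parameters implies for 3→32 channels):

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Parameter count of an nn.Conv2d(c_in, c_out, k) layer."""
    return c_out * c_in * k * k + (c_out if bias else 0)

def linear_params(n_in, n_out, bias=True):
    """Parameter count of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + (n_out if bias else 0)

def batchnorm2d_params(c):
    """Learnable parameters (weight + bias) of an nn.BatchNorm2d(c) layer."""
    return 2 * c

# Rows taken from the SimpleNet summary above:
print(conv2d_params(3, 32, 5))    # first conv   -> matches 2,432 in the table
print(batchnorm2d_params(32))     # first BN     -> matches 64
print(linear_params(1024, 2048))  # first FC     -> matches 2,099,200
```

The same three formulas reproduce every Param # row in these tables, which is a useful cross-check when a custom block reports an unexpected total.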
2, 2] 265,920 | └─_Transition: 2-10 [-1, 896, 1, 1] -- | | └─BatchNorm2d: 3-75 [-1, 1792, 2, 2] 3,584 | | └─ReLU: 3-76 [-1, 1792, 2, 2] -- | | └─Conv2d: 3-77 [-1, 896, 2, 2] 1,605,632 | | └─AvgPool2d: 3-78 [-1, 896, 1, 1] -- | └─_DenseBlock: 2-11 [-1, 1920, 1, 1] -- | | └─_DenseLayer: 3-79 [-1, 32, 1, 1] 153,600 | | └─_DenseLayer: 3-80 [-1, 32, 1, 1] 157,760 | | └─_DenseLayer: 3-81 [-1, 32, 1, 1] 161,920 | | └─_DenseLayer: 3-82 [-1, 32, 1, 1] 166,080 | | └─_DenseLayer: 3-83 [-1, 32, 1, 1] 170,240 | | └─_DenseLayer: 3-84 [-1, 32, 1, 1] 174,400 | | └─_DenseLayer: 3-85 [-1, 32, 1, 1] 178,560 | | └─_DenseLayer: 3-86 [-1, 32, 1, 1] 182,720 | | └─_DenseLayer: 3-87 [-1, 32, 1, 1] 186,880 | | └─_DenseLayer: 3-88 [-1, 32, 1, 1] 191,040 | | └─_DenseLayer: 3-89 [-1, 32, 1, 1] 195,200 | | └─_DenseLayer: 3-90 [-1, 32, 1, 1] 199,360 | | └─_DenseLayer: 3-91 [-1, 32, 1, 1] 203,520 | | └─_DenseLayer: 3-92 [-1, 32, 1, 1] 207,680 | | └─_DenseLayer: 3-93 [-1, 32, 1, 1] 211,840 | | └─_DenseLayer: 3-94 [-1, 32, 1, 1] 216,000 | | └─_DenseLayer: 3-95 [-1, 32, 1, 1] 220,160 | | └─_DenseLayer: 3-96 [-1, 32, 1, 1] 224,320 | | └─_DenseLayer: 3-97 [-1, 32, 1, 1] 228,480 | | └─_DenseLayer: 3-98 [-1, 32, 1, 1] 232,640 | | └─_DenseLayer: 3-99 [-1, 32, 1, 1] 236,800 | | └─_DenseLayer: 3-100 [-1, 32, 1, 1] 240,960 | | └─_DenseLayer: 3-101 [-1, 32, 1, 1] 245,120 | | └─_DenseLayer: 3-102 [-1, 32, 1, 1] 249,280 | | └─_DenseLayer: 3-103 [-1, 32, 1, 1] 253,440 | | └─_DenseLayer: 3-104 [-1, 32, 1, 1] 257,600 | | └─_DenseLayer: 3-105 [-1, 32, 1, 1] 261,760 | | └─_DenseLayer: 3-106 [-1, 32, 1, 1] 265,920 | | └─_DenseLayer: 3-107 [-1, 32, 1, 1] 270,080 | | └─_DenseLayer: 3-108 [-1, 32, 1, 1] 274,240 | | └─_DenseLayer: 3-109 [-1, 32, 1, 1] 278,400 | | └─_DenseLayer: 3-110 [-1, 32, 1, 1] 282,560 | └─BatchNorm2d: 2-12 [-1, 1920, 1, 1] 3,840 ├─Linear: 1-2 [-1, 100] 192,100 ========================================================================================== Total params: 18,285,028 Trainable params: 
18,285,028 Non-trainable params: 0 Total mult-adds (M): 139.99 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 5.07 Params size (MB): 69.75 Estimated Total Size (MB): 74.83 ==========================================================================================
From the above summaries I have given each baseline model an alias and compiled the number of parameters each one has.
| Model | Layers (Conv+FC) | Parameters |
|---|---|---|
| Baseline | 13 | 9,198,276 |
| ResNet34 | 34 | 23,712,932 |
| ResNeXtCustom1 | 50 | 23,184,804 |
| CoAtNetCustom1 | 26 | 32,396,096 |
| DenseNetCustom1 | 171 | 20,013,928 |
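The counts in the table come from torchsummary, but they can also be recomputed directly as a sanity check. The sketch below shows the idea on a stand-in `nn.Linear` module (the actual models are defined earlier in the notebook):

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stand-in module: nn.Linear(10, 5) has 10*5 weights + 5 biases = 55 params
print(count_params(nn.Linear(10, 5)))  # 55
```

Running `count_params` on any of the models above should reproduce the corresponding table entry.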
criterion = nn.CrossEntropyLoss()
My learning-rate decay is rather slow. However, note that I will also be using momentum, which has an effect on the update size similar to additional learning-rate decay.
epochs = 120
# Dummy model and optimizer, used only to visualize the schedule
model = torch.nn.Linear(2, 1)
opt = optim.Adam(model.parameters(), lr=0.01)
lambda1 = lambda epoch: 0.976 ** epoch  # exponential decay per epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda1)
lrs = []
for i in range(epochs):
    opt.step()
    lrs.append(opt.param_groups[0]["lr"])
    scheduler.step()
plt.xlabel('epochs')
plt.ylabel('learning rate')
plt.plot(range(epochs), lrs)
plt.show()
del model, opt
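To quantify how slow this decay really is, a little arithmetic on the per-epoch factor of 0.976 gives the schedule's "half-life":

```python
import math

decay = 0.976
# Solve decay**t = 0.5 for t: the number of epochs until the lr halves
half_life = math.log(0.5) / math.log(decay)
print(round(half_life, 1))  # 28.5 epochs per halving
# Over the full 120 epochs the lr decays to roughly 5% of its initial value
print(round(decay ** 120, 3))
```

So the learning rate only halves about every 29 epochs, which is why momentum matters for effectively damping the step size earlier in training.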
valloader = DataLoader(TensorDataset(valdata_input.type('torch.FloatTensor'),valdata_label.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
valloader_coarse = DataLoader(TensorDataset(valdata_input.type('torch.FloatTensor'),vallabel_coarse.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
def _train(model, model_desc, parameters_desc, fast=False, df=hist_df, labels=traindata_label, val_loader=valloader):
    global hist_df
    # 'fast' selects a slightly higher base learning rate
    if fast:
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.00001)
    else:
        optimizer = optim.AdamW(model.parameters(), lr=0.0008, weight_decay=0.00001)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.to(device)
    else:
        model = model.to(device)
    train_loss, train_acc, val_loss, val_acc, epoch_his = train(model, traindata_input, labels, optimizer, 125, criterion, val_loader, earlyStopper, scheduler=scheduler, augmentEvery=2, displayEvery=10)
    df = keep_history(df, train_acc, train_loss, val_loss, val_acc, epoch_his, model_desc, parameters_desc)
    return train_loss, train_acc, val_loss, val_acc, epoch_his, df
Training Baseline model
print('Baseline - VGGNet inspired')
train_loss1, train_acc1, val_loss1, val_acc1, epoch_his1, hist_df = _train(Baseline,'Baseline (VGG inspired)','9,198,276',False,df=hist_df)
Baseline - VGGNet inspired
- [Epoch 10/125] | Train Loss: 1.535 | Train Accuracy: 57.02 | Val Loss: 2.123 | Val Accuracy: 46.58 | Est: 76.95s
- [Epoch 20/125] | Train Loss: 0.849 | Train Accuracy: 76.31375 | Val Loss: 2.652 | Val Accuracy: 48.88 | Est: 76.94s
- [Epoch 30/125] | Train Loss: 0.453 | Train Accuracy: 87.94125 | Val Loss: 2.942 | Val Accuracy: 50.29 | Est: 77.06s
- [Epoch 40/125] | Train Loss: 0.500 | Train Accuracy: 87.27125 | Val Loss: 2.985 | Val Accuracy: 51.85 | Est: 77.02s
- [Epoch 50/125] | Train Loss: 0.230 | Train Accuracy: 94.5325 | Val Loss: 3.210 | Val Accuracy: 52.07 | Est: 77.02s
- [Epoch 60/125] | Train Loss: 0.263 | Train Accuracy: 93.105 | Val Loss: 3.321 | Val Accuracy: 52.06 | Est: 76.95s
- [Epoch 70/125] | Train Loss: 0.242 | Train Accuracy: 94.1975 | Val Loss: 3.150 | Val Accuracy: 52.97 | Est: 77.00s
- [Epoch 80/125] | Train Loss: 0.254 | Train Accuracy: 94.27 | Val Loss: 3.015 | Val Accuracy: 52.76 | Est: 76.96s
EarlyStopper triggered at epochs: 89
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 53.37 @ epoch 71 | Lowest Val Loss: 1.9811143934726716 @ epoch 11
display(hist_df.iloc[0])
plot_history(train_loss1, train_acc1, val_loss1, val_acc1, epoch_his1,'Baseline')
# del train_loss1, train_acc1, val_loss1, val_acc1, epoch_his1,Baseline
Lowest Val Loss                           1.981114
Highest Val Acc                           53.37
Lowest Train Loss                         0.095882
Highest Train Acc                         97.5875
Epoch (Highest Val Acc)                   71
Model Description                         Baseline (VGG inspired)
Parameter                                 9,198,276
Name: 0, dtype: object
Training ResNet and ResNeXt models
print('ResNet34')
train_loss3, train_acc3, val_loss3, val_acc3, epoch_his3, hist_df = _train(ResNet34,'ResNet34','23,712,932',True,df=hist_df)
print('ResNeXtCustom1')
train_loss5, train_acc5, val_loss5, val_acc5, epoch_his5, hist_df = _train(ResNeXtCustom1,'ResNeXt-M','23,184,804',True,df=hist_df)
# del train_loss2, train_acc2, val_loss2, val_acc2, epoch_his2,ResNet18
ResNet34
- [Epoch 10/125] | Train Loss: 0.701 | Train Accuracy: 81.52625 | Val Loss: 3.055 | Val Accuracy: 38.55 | Est: 178.14s
- [Epoch 20/125] | Train Loss: 0.385 | Train Accuracy: 90.585 | Val Loss: 3.497 | Val Accuracy: 41.33 | Est: 178.28s
- [Epoch 30/125] | Train Loss: 0.046 | Train Accuracy: 99.10125 | Val Loss: 3.293 | Val Accuracy: 45.22 | Est: 179.91s
- [Epoch 40/125] | Train Loss: 0.116 | Train Accuracy: 97.31875 | Val Loss: 3.209 | Val Accuracy: 46.76 | Est: 180.28s
- [Epoch 50/125] | Train Loss: 0.214 | Train Accuracy: 94.8975 | Val Loss: 3.260 | Val Accuracy: 47.0 | Est: 178.33s
- [Epoch 60/125] | Train Loss: 0.174 | Train Accuracy: 95.78375 | Val Loss: 3.165 | Val Accuracy: 48.19 | Est: 178.31s
- [Epoch 70/125] | Train Loss: 0.339 | Train Accuracy: 91.4025 | Val Loss: 3.247 | Val Accuracy: 47.84 | Est: 178.36s
EarlyStopper triggered at epochs: 72
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 48.86 @ epoch 54 | Lowest Val Loss: 2.512834978103638 @ epoch 7
ResNeXtCustom1
- [Epoch 10/125] | Train Loss: 0.400 | Train Accuracy: 89.6975 | Val Loss: 3.257 | Val Accuracy: 40.1 | Est: 210.34s
- [Epoch 20/125] | Train Loss: 0.313 | Train Accuracy: 92.19 | Val Loss: 3.097 | Val Accuracy: 46.3 | Est: 212.62s
- [Epoch 30/125] | Train Loss: 0.281 | Train Accuracy: 93.5525 | Val Loss: 3.095 | Val Accuracy: 46.32 | Est: 212.30s
- [Epoch 40/125] | Train Loss: 0.123 | Train Accuracy: 96.895 | Val Loss: 3.111 | Val Accuracy: 47.77 | Est: 210.74s
- [Epoch 50/125] | Train Loss: 0.082 | Train Accuracy: 98.02625 | Val Loss: 3.069 | Val Accuracy: 48.17 | Est: 210.72s
- [Epoch 60/125] | Train Loss: 0.043 | Train Accuracy: 99.0775 | Val Loss: 3.118 | Val Accuracy: 48.86 | Est: 210.53s
EarlyStopper triggered at epochs: 67
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 48.89 @ epoch 49 | Lowest Val Loss: 2.481324315071106 @ epoch 3
display(hist_df.iloc[1])
plot_history(train_loss3, train_acc3, val_loss3, val_acc3, epoch_his3,'ResNet34')
display(hist_df.iloc[2])
plot_history(train_loss5, train_acc5, val_loss5, val_acc5, epoch_his5,'ResNeXtCustom1')
Lowest Val Loss                           2.512835
Highest Val Acc                           48.86
Lowest Train Loss                         0.045856
Highest Train Acc                         99.10125
Epoch (Highest Val Acc)                   54
Model Description                         ResNet34
Parameter                                 23,712,932
Name: 0, dtype: object
Lowest Val Loss                           2.481324
Highest Val Acc                           48.89
Lowest Train Loss                         0.034021
Highest Train Acc                         99.3575
Epoch (Highest Val Acc)                   49
Model Description                         ResNeXt-M
Parameter                                 23,184,804
Name: 0, dtype: object
Training CoAtNet models
print('CoAtNetCustom1')
train_loss6, train_acc6, val_loss6, val_acc6, epoch_his6, hist_df = _train(CoAtNetCustom1,'CoAtNet-S','32,396,096',True,df=hist_df)
CoAtNetCustom1
- [Epoch 10/125] | Train Loss: 0.542 | Train Accuracy: 84.01125 | Val Loss: 2.978 | Val Accuracy: 45.76 | Est: 268.74s
- [Epoch 20/125] | Train Loss: 0.279 | Train Accuracy: 92.30125 | Val Loss: 3.093 | Val Accuracy: 50.5 | Est: 266.50s
- [Epoch 30/125] | Train Loss: 0.065 | Train Accuracy: 98.24875 | Val Loss: 3.200 | Val Accuracy: 52.15 | Est: 266.81s
- [Epoch 40/125] | Train Loss: 0.124 | Train Accuracy: 97.165 | Val Loss: 3.048 | Val Accuracy: 52.74 | Est: 266.91s
- [Epoch 50/125] | Train Loss: 0.122 | Train Accuracy: 97.245 | Val Loss: 3.028 | Val Accuracy: 52.99 | Est: 266.80s
- [Epoch 60/125] | Train Loss: 0.017 | Train Accuracy: 99.6875 | Val Loss: 3.175 | Val Accuracy: 53.16 | Est: 266.68s
- [Epoch 70/125] | Train Loss: 0.029 | Train Accuracy: 99.395 | Val Loss: 3.122 | Val Accuracy: 54.01 | Est: 265.22s
- [Epoch 80/125] | Train Loss: 0.020 | Train Accuracy: 99.64625 | Val Loss: 3.046 | Val Accuracy: 54.66 | Est: 266.51s
- [Epoch 90/125] | Train Loss: 0.034 | Train Accuracy: 99.30625 | Val Loss: 3.057 | Val Accuracy: 54.29 | Est: 266.69s
EarlyStopper triggered at epochs: 98
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 54.66 @ epoch 80 | Lowest Val Loss: 2.2023961901664735 @ epoch 7
display(hist_df.iloc[3])
plot_history(train_loss6, train_acc6, val_loss6, val_acc6, epoch_his6,'CoAtNetCustom1')
Lowest Val Loss                           2.202396
Highest Val Acc                           54.66
Lowest Train Loss                         0.004562
Highest Train Acc                         99.94875
Epoch (Highest Val Acc)                   80
Model Description                         CoAtNet-S
Parameter                                 32,396,096
Name: 0, dtype: object
Training DenseNet models
print('DenseNetCustom1')
train_loss10, train_acc10, val_loss10, val_acc10, epoch_his10, hist_df = _train(DenseNetCustom1,'DenseNet-M','20,013,928',True,df=hist_df)
DenseNetCustom1
- [Epoch 10/125] | Train Loss: 0.856 | Train Accuracy: 76.645 | Val Loss: 3.126 | Val Accuracy: 38.89 | Est: 358.35s
- [Epoch 20/125] | Train Loss: 0.291 | Train Accuracy: 92.66875 | Val Loss: 3.058 | Val Accuracy: 45.33 | Est: 358.03s
- [Epoch 30/125] | Train Loss: 0.270 | Train Accuracy: 93.2875 | Val Loss: 3.005 | Val Accuracy: 47.64 | Est: 355.57s
- [Epoch 40/125] | Train Loss: 0.123 | Train Accuracy: 97.0375 | Val Loss: 3.013 | Val Accuracy: 48.38 | Est: 355.84s
- [Epoch 50/125] | Train Loss: 0.162 | Train Accuracy: 96.18625 | Val Loss: 3.087 | Val Accuracy: 47.93 | Est: 353.68s
- [Epoch 60/125] | Train Loss: 0.248 | Train Accuracy: 93.7425 | Val Loss: 3.084 | Val Accuracy: 48.43 | Est: 356.63s
- [Epoch 70/125] | Train Loss: 0.243 | Train Accuracy: 94.015 | Val Loss: 3.065 | Val Accuracy: 48.89 | Est: 357.46s
- [Epoch 80/125] | Train Loss: 0.217 | Train Accuracy: 94.63625 | Val Loss: 3.013 | Val Accuracy: 49.37 | Est: 357.55s
- [Epoch 90/125] | Train Loss: 0.100 | Train Accuracy: 97.475 | Val Loss: 2.973 | Val Accuracy: 49.71 | Est: 357.63s
EarlyStopper triggered at epochs: 90
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 49.75 @ epoch 72 | Lowest Val Loss: 2.366958570480347 @ epoch 7
display(hist_df.iloc[4])
plot_history(train_loss10, train_acc10, val_loss10, val_acc10, epoch_his10,'DenseNetCustom1')
Lowest Val Loss                           2.366959
Highest Val Acc                           49.75
Lowest Train Loss                         0.030436
Highest Train Acc                         99.52
Epoch (Highest Val Acc)                   72
Model Description                         DenseNet-M
Parameter                                 20,013,928
Name: 0, dtype: object
print('Baseline - VGGNet inspired (Coarse)')
train_loss11, train_acc11, val_loss11, val_acc11, epoch_his11, hist_df = _train(BaselineCoarse,'Baseline (VGG inspired) - Coarse','9,198,276',False,df=hist_df,labels=trainlabel_coarse,val_loader=valloader_coarse)
Baseline - VGGNet inspired (Coarse)
- [Epoch 10/125] | Train Loss: 0.903 | Train Accuracy: 70.94 | Val Loss: 1.465 | Val Accuracy: 59.79 | Est: 76.26s
- [Epoch 20/125] | Train Loss: 0.265 | Train Accuracy: 91.56 | Val Loss: 1.842 | Val Accuracy: 64.08 | Est: 78.72s
- [Epoch 30/125] | Train Loss: 0.271 | Train Accuracy: 91.88 | Val Loss: 1.806 | Val Accuracy: 66.77 | Est: 76.36s
- [Epoch 40/125] | Train Loss: 0.189 | Train Accuracy: 94.23875 | Val Loss: 1.979 | Val Accuracy: 67.61 | Est: 76.52s
- [Epoch 50/125] | Train Loss: 0.136 | Train Accuracy: 95.94375 | Val Loss: 1.853 | Val Accuracy: 67.96 | Est: 76.38s
- [Epoch 60/125] | Train Loss: 0.092 | Train Accuracy: 97.36 | Val Loss: 1.906 | Val Accuracy: 67.58 | Est: 76.43s
EarlyStopper triggered at epochs: 65
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 68.8 @ epoch 47 | Lowest Val Loss: 1.2529150903224946 @ epoch 13
display(hist_df.iloc[5])
plot_history(train_loss11, train_acc11, val_loss11, val_acc11, epoch_his11,'Baseline (VGG inspired) - Coarse')
Lowest Val Loss                           1.252915
Highest Val Acc                           68.8
Lowest Train Loss                         0.092217
Highest Train Acc                         97.36
Epoch (Highest Val Acc)                   47
Model Description                         Baseline (VGG inspired) - Coarse
Parameter                                 9,198,276
Name: 0, dtype: object
print('ResNeXtCustom1 (Coarse)')
train_loss12, train_acc12, val_loss12, val_acc12, epoch_his12, hist_df = _train(ResNeXtCustom1Coarse,'ResNeXt-M - Coarse','23,184,804',True,df=hist_df,labels=trainlabel_coarse,val_loader=valloader_coarse)
ResNeXtCustom1 (Coarse)
- [Epoch 10/125] | Train Loss: 0.372 | Train Accuracy: 88.43875 | Val Loss: 2.317 | Val Accuracy: 54.61 | Est: 209.10s
- [Epoch 20/125] | Train Loss: 0.215 | Train Accuracy: 93.38875 | Val Loss: 2.516 | Val Accuracy: 57.86 | Est: 209.42s
- [Epoch 30/125] | Train Loss: 0.134 | Train Accuracy: 96.2 | Val Loss: 2.405 | Val Accuracy: 56.64 | Est: 211.34s
- [Epoch 40/125] | Train Loss: 0.153 | Train Accuracy: 95.32875 | Val Loss: 2.650 | Val Accuracy: 57.73 | Est: 209.39s
- [Epoch 50/125] | Train Loss: 0.163 | Train Accuracy: 95.3125 | Val Loss: 2.430 | Val Accuracy: 59.71 | Est: 209.08s
- [Epoch 60/125] | Train Loss: 0.060 | Train Accuracy: 98.35125 | Val Loss: 2.255 | Val Accuracy: 60.49 | Est: 211.36s
EarlyStopper triggered at epochs: 64
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 60.92 @ epoch 46 | Lowest Val Loss: 1.5609867513179778 @ epoch 7
display(hist_df.iloc[6])
plot_history(train_loss12, train_acc12, val_loss12, val_acc12, epoch_his12,'ResNeXt-M - Coarse')
Lowest Val Loss                           1.560987
Highest Val Acc                           60.92
Lowest Train Loss                         0.026511
Highest Train Acc                         99.3425
Epoch (Highest Val Acc)                   46
Model Description                         ResNeXt-M - Coarse
Parameter                                 23,184,804
Name: 0, dtype: object
print('CoAtNetCustom1 - Coarse')
train_loss13, train_acc13, val_loss13, val_acc13, epoch_his13, hist_df = _train(CoAtNetCustom1Coarse,'CoAtNet-S - Coarse','32,396,096',True,df=hist_df,labels=trainlabel_coarse,val_loader=valloader_coarse)
CoAtNetCustom1 - Coarse
- [Epoch 10/125] | Train Loss: 0.531 | Train Accuracy: 82.5525 | Val Loss: 1.981 | Val Accuracy: 53.36 | Est: 271.33s
- [Epoch 20/125] | Train Loss: 0.244 | Train Accuracy: 92.25125 | Val Loss: 2.402 | Val Accuracy: 58.48 | Est: 271.67s
- [Epoch 30/125] | Train Loss: 0.081 | Train Accuracy: 97.5225 | Val Loss: 2.413 | Val Accuracy: 61.57 | Est: 276.92s
- [Epoch 40/125] | Train Loss: 0.166 | Train Accuracy: 94.8975 | Val Loss: 2.905 | Val Accuracy: 62.53 | Est: 272.13s
- [Epoch 50/125] | Train Loss: 0.055 | Train Accuracy: 98.33875 | Val Loss: 2.843 | Val Accuracy: 63.0 | Est: 270.02s
- [Epoch 60/125] | Train Loss: 0.083 | Train Accuracy: 97.69125 | Val Loss: 2.462 | Val Accuracy: 62.67 | Est: 272.40s
- [Epoch 70/125] | Train Loss: 0.049 | Train Accuracy: 98.71625 | Val Loss: 2.479 | Val Accuracy: 63.41 | Est: 270.50s
- [Epoch 80/125] | Train Loss: 0.049 | Train Accuracy: 98.5975 | Val Loss: 2.400 | Val Accuracy: 63.59 | Est: 277.23s
- [Epoch 90/125] | Train Loss: 0.090 | Train Accuracy: 97.37125 | Val Loss: 2.241 | Val Accuracy: 63.6 | Est: 272.17s
- [Epoch 100/125] | Train Loss: 0.263 | Train Accuracy: 92.12125 | Val Loss: 2.187 | Val Accuracy: 63.26 | Est: 272.61s
EarlyStopper triggered at epochs: 103
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 64.04 @ epoch 85 | Lowest Val Loss: 1.4752365529537201 @ epoch 5
display(hist_df.iloc[7])
plot_history(train_loss13, train_acc13, val_loss13, val_acc13, epoch_his13,'CoAtNet-S - Coarse')
Lowest Val Loss                           1.475237
Highest Val Acc                           64.04
Lowest Train Loss                         0.001255
Highest Train Acc                         99.995
Epoch (Highest Val Acc)                   85
Model Description                         CoAtNet-S - Coarse
Parameter                                 32,396,096
Name: 0, dtype: object
hist_df
| | Lowest Val Loss | Highest Val Acc | Lowest Train Loss | Highest Train Acc | Epoch (Highest Val Acc) | Model Description | Parameter |
|---|---|---|---|---|---|---|---|
| 0 | 1.981114 | 53.37 | 0.095882 | 97.58750 | 71 | Baseline (VGG inspired) | 9,198,276 |
| 0 | 2.512835 | 48.86 | 0.045856 | 99.10125 | 54 | ResNet34 | 23,712,932 |
| 0 | 2.481324 | 48.89 | 0.034021 | 99.35750 | 49 | ResNeXt-M | 23,184,804 |
| 0 | 2.202396 | 54.66 | 0.004562 | 99.94875 | 80 | CoAtNet-S | 32,396,096 |
| 0 | 2.366959 | 49.75 | 0.030436 | 99.52000 | 72 | DenseNet-M | 20,013,928 |
| 0 | 1.252915 | 68.80 | 0.092217 | 97.36000 | 47 | Baseline (VGG inspired) - Coarse | 9,198,276 |
| 0 | 1.560987 | 60.92 | 0.026511 | 99.34250 | 46 | ResNeXt-M - Coarse | 23,184,804 |
| 0 | 1.475237 | 64.04 | 0.001255 | 99.99500 | 85 | CoAtNet-S - Coarse | 32,396,096 |
It seems CoAtNet performs best on the fine labels, while the Baseline performs best on the coarse-labelled data. I suspect that CoAtNet is actually the strongest architecture for both label sets; however, it overfits far more than my simpler Baseline model (which tops out around 97% training accuracy). The highest training accuracy of the complex models is always very high (99%+), which means the weight decay and, more importantly, the data augmentation have to be much stronger before the complex models can overtake the Baseline here. I believe that with much stronger data augmentation, CoAtNet can outperform the Baseline with no problem, so much of the model-improvement effort will go into augmenting the data.
Can I still do hyperparameter tuning as my models become more and more complex, given my current computational power? Yes, but unlike in Part A, the tuning method has to be more efficient. In Part A (FashionMNIST), coarse-to-fine search was a perfect fit because the model was not too complex and I could run about 150 trials every 10 hours. With millions upon millions of parameters, however, a more efficient hyperparameter tuning technique is needed.
What makes hyperparameter tuning a complex model so time- and compute-intensive? It is the number of epochs, the FLOPs per epoch, and the time it takes to get anywhere close to a local minimum. Say I need an average of 250 epochs to reach a local minimum of my complex model; running 150 trials of 250 epochs each would take so long as to be impractical, even for a technique like the coarse-to-fine search used in Part A [Kiddon et al., 2010]. However, I can cut down the epochs, FLOPs and time of each trial by removing the most time-intensive part of training a complex model: the start. If I first train a model with a slightly lower learning rate and a relatively higher weight decay for 200 epochs, I can simply start every hyperparameter tuning trial from epoch 200. This cuts down my tuning time by skipping those initial (time-consuming) 200 epochs. However, some rules have to be applied, as follows.
First, augmentation has to be reapplied regularly. If no fresh augmentation is applied after transferring the weights of my model at epoch 200, the local minimum remains the same, because the training set does not change at all, regardless of whether augmentation was applied before training. Secondly, I will use a relatively high weight decay during the first 200 epochs; this allows the weights during hyperparameter tuning to adapt more readily to the new hyperparameters that I set.
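The warm-start procedure above can be sketched as follows. The `nn.Linear` module stands in for the real network and the checkpoint path is a placeholder, not an actual file from this notebook:

```python
import os, tempfile
import torch
import torch.nn as nn

# Hypothetical warm start: snapshot the shared weights once (the
# "epoch 200" checkpoint), then begin every tuning trial from it.
model = nn.Linear(8, 4)  # stand-in for the real network
ckpt_path = os.path.join(tempfile.gettempdir(), 'warm_start_epoch200.pt')
torch.save(model.state_dict(), ckpt_path)

trial_model = nn.Linear(8, 4)  # fresh copy for one tuning trial
trial_model.load_state_dict(torch.load(ckpt_path))
# From here on, only the trial's own hyperparameters (lr, weight decay,
# augmentation schedule) differ; the costly first 200 epochs are skipped.
```

Each trial then pays only for the epochs after the shared checkpoint, which is where the time savings come from.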
Models we are tuning:
**What are some ways to improve our models for fine/coarse dataset?**
1) Choosing the right augmentation & sampling/replication method:
2) Tuning model complexity:
3) Hyperparameter tuning (see Discussion):
4) Transfer learning (between my own models only) - [no external weights]:
The SAM optimizer will be used since, as seen in Part A and in the original technical paper [Foret et al., 2020], it has shown consistently better generalization than the SGD optimizer.
!git clone https://github.com/davda54/sam.git
Cloning into 'sam'... remote: Enumerating objects: 200, done. remote: Counting objects: 100% (96/96), done. remote: Compressing objects: 100% (36/36), done. remote: Total 200 (delta 73), reused 66 (delta 60), pack-reused 104 Receiving objects: 100% (200/200), 659.65 KiB | 3.58 MiB/s, done. Resolving deltas: 100% (95/95), done.
from sam.sam import SAM
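To make clear what SAM's `first_step`/`second_step` pair does, here is a simplified, self-contained re-implementation for exposition only (my own sketch, not the code from the cloned repository): climb to the worst-case weights within a small radius, take the gradient there, then update the original weights with that gradient.

```python
import torch

def sam_step(params, closure, base_opt, rho=0.05):
    """Illustrative SAM update: perturb weights toward the local loss
    maximum (radius rho), take the gradient there, then let the base
    optimizer update the ORIGINAL weights with that gradient."""
    closure().backward()                       # first forward-backward pass
    with torch.no_grad():
        grad_norm = torch.norm(torch.stack(
            [p.grad.norm(p=2) for p in params if p.grad is not None]))
        eps = []
        for p in params:
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)                          # perturb: w -> w + e
            eps.append(e)
            p.grad = None
    closure().backward()                       # second pass, at w + e
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)                          # restore original weights
    base_opt.step()                            # base update (e.g. SGD)
    base_opt.zero_grad()

# Toy usage: minimize f(w) = w^2 with SAM-style steps
w = torch.tensor([3.0], requires_grad=True)
sgd = torch.optim.SGD([w], lr=0.1)
for _ in range(50):
    sam_step([w], lambda: (w ** 2).sum(), sgd)
```

The two forward-backward passes per batch are exactly why SAM roughly doubles the per-epoch cost, and why a dedicated training loop is needed below.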
df_augmentation = pd.DataFrame([],columns=['Lowest Val Loss','Highest Val Acc','Lowest Train Loss','Epoch (Highest Val Acc)', 'Model Description', 'Parameter'])
A new training function has to be written for the SAM optimizer, since SAM requires two forward-backward passes per batch.
def train_sam(model, input_, label_, optimizer, NUM_EPOCHS, criterion, valloader=None, earlystopper=None, showEpoch=True, scheduler=None, displayEvery=1, augmentEvery=None, augCompose1_=imageAutoAugment, augCompose2_=imageAutoAugment, sampleMode=2, augBatch=1250, classNo=100):
    train_loss_his = []
    train_accuracy_his = []
    val_loss_his = []
    val_accuracy_his = []
    epoch_his = []
    earlystopper.reset()
    model.train()
    for epoch in range(NUM_EPOCHS):
        t0 = time.time()
        total = 0
        correct = 0
        running_loss = 0.0
        if augmentEvery is not None and epoch % augmentEvery == 0:
            if sampleMode == 1:
                input_rep, label_rep = replicateData1(traindata=input_, augCompose=augCompose1_, augBatch=augBatch, trainlabel=label_, classNum_=classNo)
            elif sampleMode == 2:
                input_rep, label_rep = replicateData2(traindata=input_, augCompose=augCompose1_, augBatch=augBatch, trainlabel=label_)
            elif sampleMode == 3:
                input_rep, label_rep = replicateData3(traindata=input_, augCompose1=augCompose1_, augCompose2=augCompose2_, augBatch=augBatch, trainlabel=label_)
            elif sampleMode == 4:
                input_rep, label_rep = replicateData4(traindata=input_, augCompose1=augCompose1_, augCompose2=augCompose2_, augBatch=augBatch, trainlabel=label_)
            elif sampleMode == 5:
                input_rep, label_rep = replicateData5(augCompose1=augCompose1_, augBatch=augBatch, trainlabel=label_)
            # The augmented dataset is already shuffled
            loader = DataLoader(TensorDataset(input_rep.type('torch.FloatTensor'), label_rep.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
        else:
            # Fixed: use the function arguments here ('label' was undefined)
            loader = DataLoader(TensorDataset(input_.type('torch.FloatTensor'), label_.type('torch.FloatTensor')), shuffle=True, batch_size=BATCH_SIZE)
        for data in loader:
            img, label = data
            img = img.to(device)
            label = label.to(device)
            outputs = model(img)
            loss = criterion(outputs, label)  # use this loss for any training statistics
            loss.backward()
            optimizer.first_step(zero_grad=True)
            # second forward-backward pass, at the perturbed weights
            outputs = model(img)
            criterion(outputs, label).backward()  # make sure to do a full forward pass
            optimizer.second_step(zero_grad=True)
            correct += (outputs.argmax(-1) == label.argmax(-1)).sum().item()
            total += label.size(0)
            running_loss += loss.item()
        if scheduler is not None:
            scheduler.step()
        accuracy = 100 * correct / total
        loss = running_loss / len(loader)
        train_loss_his.append(loss)
        train_accuracy_his.append(accuracy)
        val_loss, val_accuracy = validate(model, criterion, valloader)
        val_loss_his.append(val_loss)
        val_accuracy_his.append(val_accuracy)
        epoch_his.append(epoch + 1)
        if showEpoch and (epoch + 1) % displayEvery == 0:
            print(f'- [Epoch {epoch+1}/{NUM_EPOCHS}] | Train Loss: {loss:.3f} | Train Accuracy: {accuracy} | Val Loss: {val_loss:.3f} | Val Accuracy: {val_accuracy} | Est: {(time.time() - t0)*displayEvery:.2f}s')
        if earlystopper(val_accuracy, val_loss):
            print(f'EarlyStopper triggered at epochs: {epoch+1} \n*No improvement to validation loss and accuracy could be seen for the past {earlystopper.patience} epochs')
            break
    print(f'Highest Val Accuracy: {max(val_accuracy_his)} @ epoch {epoch_his[val_accuracy_his.index(max(val_accuracy_his))]} | Lowest Val Loss: {min(val_loss_his)} @ epoch {epoch_his[val_loss_his.index(min(val_loss_his))]}')
    return train_loss_his, train_accuracy_his, val_loss_his, val_accuracy_his, epoch_his
Initializing our models for the fine and coarse data
def _train(model_desc, augCompose1, augCompose2, df, sampleMode=3, fine=True, lr_=0.01):
    global df_augmentation
    if fine:
        parameters_desc = "32,396,096"
        model = CoAtNet((32, 32), 3, [2, 2, 6, 14, 2], [64, 96, 192, 384, 768], num_classes=100)
        base_optimizer = optim.SGD
        optimizer = SAM(model.parameters(), base_optimizer, lr=lr_, momentum=0.90, weight_decay=0.0000025)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
        model = nn.DataParallel(model)
        model.to(device)
        train_loss, train_acc, val_loss, val_acc, epoch_his = train_sam(model, traindata_input, traindata_label, optimizer, 125, criterion, valloader, earlyStopper, augCompose1_=augCompose1, augCompose2_=augCompose2, scheduler=scheduler, augmentEvery=1, displayEvery=10, sampleMode=sampleMode)
    else:
        parameters_desc = "9,198,276"
        model = SimpleNet([2, 3, 4, 4], 3, 32, 2, 2048, 20)
        base_optimizer = optim.SGD
        optimizer = SAM(model.parameters(), base_optimizer, lr=lr_, momentum=0.90, weight_decay=0.0000025)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
        model = nn.DataParallel(model)
        model.to(device)
        train_loss, train_acc, val_loss, val_acc, epoch_his = train_sam(model, traindata_input, trainlabel_coarse, optimizer, 125, criterion, valloader_coarse, earlyStopper, augCompose1_=augCompose1, augCompose2_=augCompose2, scheduler=scheduler, augmentEvery=1, displayEvery=10, sampleMode=sampleMode, classNo=20)
    df = keep_history(df, train_acc, train_loss, val_loss, val_acc, epoch_his, model_desc, parameters_desc)
    return train_loss, train_acc, val_loss, val_acc, epoch_his, df
I have made some sampling/replication methods in my train_sam function:
Which combinations of augmentations to use will be chosen by intuition rather than exhaustively: permuting every possible augmentation set would be far too time-consuming.
sampleMode0: Try all different sets of augmentation I currently have on fine-labelled data
_,_,_,_,_,df_augmentation = _train("CoAtNet-AutoAugment sample:0", imageAutoAugment2,None, df_augmentation,0,True,lr_=0.014)
- [Epoch 10/125] | Train Loss: 2.287 | Train Accuracy: 37.295 | Val Loss: 2.998 | Val Accuracy: 30.36 | Est: 287.67s
- [Epoch 20/125] | Train Loss: 1.742 | Train Accuracy: 50.2675 | Val Loss: 2.422 | Val Accuracy: 40.54 | Est: 287.88s
- [Epoch 30/125] | Train Loss: 0.748 | Train Accuracy: 77.175 | Val Loss: 2.281 | Val Accuracy: 45.45 | Est: 269.64s
- [Epoch 40/125] | Train Loss: 1.206 | Train Accuracy: 66.5575 | Val Loss: 2.535 | Val Accuracy: 40.51 | Est: 299.86s
- [Epoch 50/125] | Train Loss: 0.710 | Train Accuracy: 79.505 | Val Loss: 2.469 | Val Accuracy: 42.86 | Est: 277.39s
- [Epoch 60/125] | Train Loss: 0.331 | Train Accuracy: 89.025 | Val Loss: 2.255 | Val Accuracy: 47.9 | Est: 294.00s
- [Epoch 70/125] | Train Loss: 0.523 | Train Accuracy: 83.77 | Val Loss: 2.217 | Val Accuracy: 48.6 | Est: 281.45s
- [Epoch 80/125] | Train Loss: 0.814 | Train Accuracy: 76.7725 | Val Loss: 2.815 | Val Accuracy: 38.67 | Est: 300.41s
- [Epoch 90/125] | Train Loss: 0.517 | Train Accuracy: 83.9725 | Val Loss: 2.162 | Val Accuracy: 48.8 | Est: 293.23s
- [Epoch 100/125] | Train Loss: 0.826 | Train Accuracy: 80.02 | Val Loss: 2.196 | Val Accuracy: 48.87 | Est: 292.82s
- [Epoch 110/125] | Train Loss: 0.654 | Train Accuracy: 80.44 | Val Loss: 2.301 | Val Accuracy: 46.69 | Est: 271.18s
EarlyStopper triggered at epochs: 116
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 49.41 @ epoch 98 | Lowest Val Loss: 2.152133846282959 @ epoch 96
_,_,_,_,_,df_augmentation = _train("CoAtNet-AugMix sample:0", imageAugMix2,None, df_augmentation,0,True,lr_=0.005)
- [Epoch 10/125] | Train Loss: 1.667 | Train Accuracy: 48.3825 | Val Loss: 2.860 | Val Accuracy: 30.61 | Est: 333.81s
- [Epoch 20/125] | Train Loss: 0.439 | Train Accuracy: 81.215 | Val Loss: 3.357 | Val Accuracy: 31.07 | Est: 327.65s
- [Epoch 30/125] | Train Loss: 0.347 | Train Accuracy: 86.895 | Val Loss: 3.262 | Val Accuracy: 33.59 | Est: 366.19s
- [Epoch 40/125] | Train Loss: 0.141 | Train Accuracy: 94.4025 | Val Loss: 3.276 | Val Accuracy: 34.35 | Est: 334.79s
- [Epoch 50/125] | Train Loss: 0.130 | Train Accuracy: 94.48 | Val Loss: 3.238 | Val Accuracy: 34.89 | Est: 310.43s
- [Epoch 60/125] | Train Loss: 0.153 | Train Accuracy: 93.7425 | Val Loss: 3.241 | Val Accuracy: 34.8 | Est: 346.01s
- [Epoch 70/125] | Train Loss: 0.080 | Train Accuracy: 96.5 | Val Loss: 3.168 | Val Accuracy: 35.44 | Est: 347.28s
- [Epoch 80/125] | Train Loss: 0.143 | Train Accuracy: 93.9775 | Val Loss: 3.180 | Val Accuracy: 35.79 | Est: 356.59s
- [Epoch 90/125] | Train Loss: 0.173 | Train Accuracy: 93.4125 | Val Loss: 3.144 | Val Accuracy: 36.12 | Est: 343.30s
- [Epoch 100/125] | Train Loss: 0.100 | Train Accuracy: 95.69 | Val Loss: 3.164 | Val Accuracy: 35.53 | Est: 343.24s
EarlyStopper triggered at epochs: 108
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 36.12 @ epoch 90 | Lowest Val Loss: 2.839644396305084 @ epoch 8
The epoch budget is reduced from 125 to 70, as I realised any improvement after 70 epochs is negligible.
_,_,_,_,_,df_augmentation = _train("CoAtNet-ImageNet sample:0", imageAutoAugmentImageNet2,None, df_augmentation,0,True,lr_=0.008)
- [Epoch 10/70] | Train Loss: 2.791| Train Accuracy: 26.78 | Val Loss: 2.853 | Val Accuracy: 29.48 | Est: 343.69s
- [Epoch 20/70] | Train Loss: 2.194| Train Accuracy: 39.24 | Val Loss: 2.521 | Val Accuracy: 37.46 | Est: 338.53s
- [Epoch 30/70] | Train Loss: 1.657| Train Accuracy: 51.21 | Val Loss: 2.309 | Val Accuracy: 41.61 | Est: 301.34s
- [Epoch 40/70] | Train Loss: 1.153| Train Accuracy: 66.8175 | Val Loss: 2.355 | Val Accuracy: 42.2 | Est: 325.92s
- [Epoch 50/70] | Train Loss: 1.426| Train Accuracy: 58.335 | Val Loss: 2.809 | Val Accuracy: 36.88 | Est: 309.64s
- [Epoch 60/70] | Train Loss: 0.976| Train Accuracy: 70.35 | Val Loss: 2.247 | Val Accuracy: 45.52 | Est: 332.55s
- [Epoch 70/70] | Train Loss: 0.720| Train Accuracy: 76.775 | Val Loss: 2.269 | Val Accuracy: 45.66 | Est: 287.11s
Highest Val Accuracy: 46.17 @ epoch 68 | Lowest Val Loss: 2.238827347755432 @ epoch 42
_,_,_,_,_,df_augmentation = _train("CoAtNet-CustomAugment sample:0", imageColorTrans2,None, df_augmentation,0,True)
- [Epoch 10/70] | Train Loss: 2.357| Train Accuracy: 34.71 | Val Loss: 2.756 | Val Accuracy: 31.97 | Est: 315.47s
- [Epoch 20/70] | Train Loss: 1.576| Train Accuracy: 52.715 | Val Loss: 2.584 | Val Accuracy: 37.36 | Est: 317.87s
- [Epoch 30/70] | Train Loss: 1.146| Train Accuracy: 63.215 | Val Loss: 2.607 | Val Accuracy: 38.56 | Est: 316.01s
- [Epoch 40/70] | Train Loss: 0.505| Train Accuracy: 80.555 | Val Loss: 2.625 | Val Accuracy: 40.85 | Est: 315.92s
- [Epoch 50/70] | Train Loss: 0.379| Train Accuracy: 86.13 | Val Loss: 2.624 | Val Accuracy: 42.0 | Est: 320.19s
- [Epoch 60/70] | Train Loss: 0.524| Train Accuracy: 82.5 | Val Loss: 3.046 | Val Accuracy: 36.35 | Est: 315.09s
- [Epoch 70/70] | Train Loss: 0.468| Train Accuracy: 83.3225 | Val Loss: 2.898 | Val Accuracy: 39.41 | Est: 316.28s
Highest Val Accuracy: 44.15 @ epoch 69 | Lowest Val Loss: 2.335602509975433 @ epoch 19
_,_,_,_,_,df_augmentation = _train("CoAtNet-AutoAugment+CutMix sample:1", imageAutoAugment,None, df_augmentation,1,True,lr_=0.0125)
- [Epoch 10/70] | Train Loss: 1.682| Train Accuracy: 59.98416666666667 | Val Loss: 1.946 | Val Accuracy: 49.5 | Est: 773.77s
- [Epoch 20/70] | Train Loss: 1.364| Train Accuracy: 68.8 | Val Loss: 1.823 | Val Accuracy: 53.44 | Est: 775.08s
- [Epoch 30/70] | Train Loss: 1.116| Train Accuracy: 75.86083333333333 | Val Loss: 1.769 | Val Accuracy: 54.96 | Est: 779.38s
- [Epoch 40/70] | Train Loss: 1.038| Train Accuracy: 77.82416666666667 | Val Loss: 1.764 | Val Accuracy: 55.51 | Est: 768.73s
- [Epoch 50/70] | Train Loss: 1.192| Train Accuracy: 73.16666666666667 | Val Loss: 1.721 | Val Accuracy: 55.82 | Est: 807.53s
- [Epoch 60/70] | Train Loss: 1.009| Train Accuracy: 79.55916666666667 | Val Loss: 1.706 | Val Accuracy: 56.87 | Est: 773.27s
- [Epoch 70/70] | Train Loss: 0.998| Train Accuracy: 79.66416666666667 | Val Loss: 1.702 | Val Accuracy: 56.61 | Est: 776.10s
Highest Val Accuracy: 57.16 @ epoch 67 | Lowest Val Loss: 1.6873259544372559 @ epoch 62
The output for this cell was accidentally removed.
_,_,_,_,_,df_augmentation = _train("CoAtNet-ImageNet+CutMix sample:1", imageAutoAugmentImageNet,None, df_augmentation,1,True,lr_=0.014)
_,_,_,_,_,df_augmentation = _train("CoAtNet-AutoAugment+ImageNet sample:3", imageAutoAugment,imageAutoAugmentImageNet, df_augmentation,1,True,lr_=0.008)
- [Epoch 10/70] | Train Loss: 2.006| Train Accuracy: 51.685 | Val Loss: 2.021 | Val Accuracy: 47.22 | Est: 799.56s
- [Epoch 20/70] | Train Loss: 1.418| Train Accuracy: 68.05416666666666 | Val Loss: 1.891 | Val Accuracy: 51.05 | Est: 756.33s
- [Epoch 30/70] | Train Loss: 1.283| Train Accuracy: 71.29583333333333 | Val Loss: 1.836 | Val Accuracy: 52.88 | Est: 783.18s
- [Epoch 40/70] | Train Loss: 1.239| Train Accuracy: 71.75333333333333 | Val Loss: 1.791 | Val Accuracy: 53.74 | Est: 762.96s
- [Epoch 50/70] | Train Loss: 1.119| Train Accuracy: 77.13833333333334 | Val Loss: 1.788 | Val Accuracy: 54.57 | Est: 773.15s
- [Epoch 60/70] | Train Loss: 1.086| Train Accuracy: 76.675 | Val Loss: 1.767 | Val Accuracy: 54.83 | Est: 774.90s
- [Epoch 70/70] | Train Loss: 1.069| Train Accuracy: 76.88666666666667 | Val Loss: 1.748 | Val Accuracy: 55.14 | Est: 789.00s
Highest Val Accuracy: 55.29 @ epoch 61 | Lowest Val Loss: 1.748326873779297 @ epoch 70
_,_,_,_,_,df_augmentation = _train("CoAtNet-AutoAugment+ImageNet sample:4", imageAutoAugment,imageAutoAugmentImageNet, df_augmentation,4,True,lr_=0.008)
- [Epoch 10/70] | Train Loss: 2.067| Train Accuracy: 42.42 | Val Loss: 2.157 | Val Accuracy: 43.24 | Est: 567.63s
- [Epoch 20/70] | Train Loss: 1.482| Train Accuracy: 57.215 | Val Loss: 2.007 | Val Accuracy: 47.93 | Est: 594.95s
- [Epoch 30/70] | Train Loss: 0.934| Train Accuracy: 71.0875 | Val Loss: 1.973 | Val Accuracy: 50.31 | Est: 588.70s
- [Epoch 40/70] | Train Loss: 0.787| Train Accuracy: 75.5 | Val Loss: 1.925 | Val Accuracy: 51.61 | Est: 597.81s
- [Epoch 50/70] | Train Loss: 0.854| Train Accuracy: 74.785 | Val Loss: 1.897 | Val Accuracy: 52.39 | Est: 617.75s
- [Epoch 60/70] | Train Loss: 0.705| Train Accuracy: 77.4475 | Val Loss: 1.913 | Val Accuracy: 52.42 | Est: 570.34s
- [Epoch 70/70] | Train Loss: 0.679| Train Accuracy: 78.16375 | Val Loss: 1.862 | Val Accuracy: 53.52 | Est: 569.28s
Highest Val Accuracy: 53.67 @ epoch 67 | Lowest Val Loss: 1.853470641374588 @ epoch 67
_,_,_,_,_,df_augmentation = _train("CoAtNet-AutoAugment sample:5", imageAutoAugment2,None, df_augmentation,5,True,lr_=0.008)
- [Epoch 20/125] | Train Loss: 1.587| Train Accuracy: 54.36375 | Val Loss: 1.808 | Val Accuracy: 52.15 | Est: 567.60s
- [Epoch 30/125] | Train Loss: 1.112| Train Accuracy: 65.52125 | Val Loss: 1.804 | Val Accuracy: 53.23 | Est: 544.91s
- [Epoch 40/125] | Train Loss: 0.864| Train Accuracy: 72.82625 | Val Loss: 1.782 | Val Accuracy: 54.98 | Est: 568.62s
- [Epoch 50/125] | Train Loss: 0.971| Train Accuracy: 71.1525 | Val Loss: 1.799 | Val Accuracy: 54.8 | Est: 560.19s
- [Epoch 60/125] | Train Loss: 0.301| Train Accuracy: 88.78875 | Val Loss: 1.823 | Val Accuracy: 55.65 | Est: 564.30s
- [Epoch 70/125] | Train Loss: 0.435| Train Accuracy: 84.74875 | Val Loss: 1.801 | Val Accuracy: 55.8 | Est: 586.98s
- [Epoch 80/125] | Train Loss: 0.938| Train Accuracy: 71.675 | Val Loss: 1.798 | Val Accuracy: 56.36 | Est: 535.43s
- [Epoch 90/125] | Train Loss: 0.599| Train Accuracy: 80.85125 | Val Loss: 1.797 | Val Accuracy: 55.91 | Est: 554.43s
- [Epoch 100/125] | Train Loss: 0.630| Train Accuracy: 78.96 | Val Loss: 1.815 | Val Accuracy: 56.32 | Est: 567.32s
EarlyStopper triggered at epochs: 106
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 56.5 @ epoch 88 | Lowest Val Loss: 1.7661678910255432 @ epoch 55
_,_,_,_,_,df_augmentation = _train("CoAtNet-ImageNet sample:5", imageAutoAugmentImageNet2,None, df_augmentation,5,True,lr_=0.01)
- [Epoch 20/125] | Train Loss: 1.742| Train Accuracy: 48.525 | Val Loss: 1.847 | Val Accuracy: 50.79 | Est: 630.41s
- [Epoch 30/125] | Train Loss: 1.212| Train Accuracy: 61.99625 | Val Loss: 1.760 | Val Accuracy: 53.85 | Est: 612.16s
- [Epoch 40/125] | Train Loss: 1.122| Train Accuracy: 65.38125 | Val Loss: 1.765 | Val Accuracy: 54.53 | Est: 642.32s
- [Epoch 70/125] | Train Loss: 0.648| Train Accuracy: 78.71 | Val Loss: 1.790 | Val Accuracy: 55.62 | Est: 678.14s
- [Epoch 90/125] | Train Loss: 0.623| Train Accuracy: 79.265 | Val Loss: 1.787 | Val Accuracy: 56.25 | Est: 629.20s
- [Epoch 100/125] | Train Loss: 0.612| Train Accuracy: 79.35 | Val Loss: 1.786 | Val Accuracy: 56.05 | Est: 603.89s
- [Epoch 110/125] | Train Loss: 0.649| Train Accuracy: 77.97375 | Val Loss: 1.801 | Val Accuracy: 55.91 | Est: 653.87s
- [Epoch 120/125] | Train Loss: 0.735| Train Accuracy: 74.7725 | Val Loss: 1.815 | Val Accuracy: 55.97 | Est: 622.59s
EarlyStopper triggered at epochs: 120
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 56.72 @ epoch 102 | Lowest Val Loss: 1.7460790634155274 @ epoch 36
_,_,_,_,_,df_augmentation = _train("CoAtNet-AugMix sample:5", imageAugMix2,None, df_augmentation,5,True,lr_=0.0085)
- [Epoch 10/125] | Train Loss: 1.327| Train Accuracy: 56.74 | Val Loss: 2.359 | Val Accuracy: 40.98 | Est: 685.58s
- [Epoch 20/125] | Train Loss: 0.365| Train Accuracy: 85.13125 | Val Loss: 2.650 | Val Accuracy: 42.86 | Est: 694.14s
- [Epoch 30/125] | Train Loss: 0.199| Train Accuracy: 91.56125 | Val Loss: 2.688 | Val Accuracy: 43.5 | Est: 673.29s
- [Epoch 40/125] | Train Loss: 0.155| Train Accuracy: 93.25875 | Val Loss: 2.655 | Val Accuracy: 44.67 | Est: 651.22s
- [Epoch 50/125] | Train Loss: 0.131| Train Accuracy: 94.5875 | Val Loss: 2.626 | Val Accuracy: 45.1 | Est: 682.66s
- [Epoch 70/125] | Train Loss: 0.109| Train Accuracy: 95.34625 | Val Loss: 2.588 | Val Accuracy: 45.34 | Est: 716.09s
df_augmentation
| | Lowest Val Loss | Highest Val Acc | Lowest Train Loss | Epoch (Highest Val Acc) | Model Description | Parameter | Highest Train Acc |
|---|---|---|---|---|---|---|---|
| 0 | 2.152134 | 49.41 | 0.077599 | 98 | CoAtNet-AutoAugment sample:0 | 32,396,096 | 97.837500 |
| 0 | 2.839644 | 36.12 | 0.072459 | 90 | CoAtNet-AugMix sample:0 | 32,396,096 | 96.557500 |
| 0 | 2.238827 | 46.17 | 0.626822 | 68 | CoAtNet-ImageNet sample:0 | 32,396,096 | 79.947500 |
| 0 | 2.335603 | 44.15 | 0.113020 | 69 | CoAtNet-CustomAugment sample:0 | 32,396,096 | 94.617500 |
| 0 | 1.687326 | 57.16 | 0.859597 | 67 | CoAtNet-AutoAugment+CutMix sample:1 | 32,396,096 | 83.082500 |
| 0 | 1.718855 | 55.71 | 1.011111 | 55 | CoAtNet-ImageNet+CutMix sample:1 | 32,396,096 | 79.366667 |
| 0 | 1.748327 | 55.29 | 0.916276 | 61 | CoAtNet-AutoAugment+ImageNet sample:3 | 32,396,096 | 80.396667 |
| 0 | 1.853471 | 53.67 | 0.602639 | 67 | CoAtNet-AutoAugment+ImageNet sample:4 | 32,396,096 | 80.353750 |
| 0 | 1.766168 | 56.50 | 0.301244 | 88 | CoAtNet-AutoAugment sample:5 | 32,396,096 | 88.788750 |
| 0 | 1.746079 | 56.72 | 0.435966 | 102 | CoAtNet-ImageNet sample:5 | 32,396,096 | 85.050000 |
| 0 | 2.358972 | 45.45 | 0.089996 | 53 | CoAtNet-AugMix sample:5 | 32,396,096 | 95.651250 |
df_augmentation2 = pd.DataFrame([],columns=['Lowest Val Loss','Highest Val Acc','Lowest Train Loss','Epoch (Highest Val Acc)', 'Model Description', 'Parameter'])
_,_,_,_,_,df_augmentation2 = _train("SimpleNet-AutoAugment+CutMix sample:1 (Coarse)", imageAutoAugment,None, df_augmentation2,1,False,lr_=0.0125)
- [Epoch 10/125] | Train Loss: 1.260| Train Accuracy: 62.505 | Val Loss: 1.086 | Val Accuracy: 65.28 | Est: 271.15s
- [Epoch 20/125] | Train Loss: 0.936| Train Accuracy: 71.52666666666667 | Val Loss: 1.009 | Val Accuracy: 70.25 | Est: 260.80s
- [Epoch 30/125] | Train Loss: 0.805| Train Accuracy: 76.92833333333333 | Val Loss: 1.002 | Val Accuracy: 71.47 | Est: 278.16s
- [Epoch 40/125] | Train Loss: 0.884| Train Accuracy: 75.04833333333333 | Val Loss: 0.999 | Val Accuracy: 71.95 | Est: 260.82s
- [Epoch 50/125] | Train Loss: 0.770| Train Accuracy: 79.7675 | Val Loss: 1.045 | Val Accuracy: 71.91 | Est: 271.93s
- [Epoch 60/125] | Train Loss: 0.581| Train Accuracy: 84.30833333333334 | Val Loss: 0.989 | Val Accuracy: 73.09 | Est: 256.20s
- [Epoch 70/125] | Train Loss: 0.666| Train Accuracy: 81.9725 | Val Loss: 0.977 | Val Accuracy: 73.37 | Est: 259.20s
- [Epoch 80/125] | Train Loss: 0.603| Train Accuracy: 85.5975 | Val Loss: 0.986 | Val Accuracy: 73.54 | Est: 265.66s
- [Epoch 90/125] | Train Loss: 0.677| Train Accuracy: 82.295 | Val Loss: 0.968 | Val Accuracy: 73.6 | Est: 273.39s
- [Epoch 100/125] | Train Loss: 0.664| Train Accuracy: 83.6275 | Val Loss: 0.973 | Val Accuracy: 73.37 | Est: 267.61s
- [Epoch 110/125] | Train Loss: 0.696| Train Accuracy: 80.96083333333333 | Val Loss: 0.975 | Val Accuracy: 73.38 | Est: 237.83s
- [Epoch 120/125] | Train Loss: 0.623| Train Accuracy: 83.30166666666666 | Val Loss: 0.946 | Val Accuracy: 73.97 | Est: 269.65s
EarlyStopper triggered at epochs: 121
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 74.06 @ epoch 103 | Lowest Val Loss: 0.9396569907665253 @ epoch 86
_,_,_,_,_,df_augmentation2 = _train("SimpleNet-ImageNet+CutMix sample:1 (Coarse)", imageAutoAugmentImageNet,None, df_augmentation2,1,False,lr_=0.0125)
- [Epoch 10/125] | Train Loss: 1.399| Train Accuracy: 56.59916666666667 | Val Loss: 1.091 | Val Accuracy: 66.07 | Est: 307.97s
- [Epoch 20/125] | Train Loss: 0.960| Train Accuracy: 73.49333333333334 | Val Loss: 0.996 | Val Accuracy: 70.02 | Est: 297.86s
- [Epoch 30/125] | Train Loss: 0.956| Train Accuracy: 73.64083333333333 | Val Loss: 0.968 | Val Accuracy: 71.97 | Est: 301.09s
- [Epoch 40/125] | Train Loss: 0.791| Train Accuracy: 78.64583333333333 | Val Loss: 0.976 | Val Accuracy: 72.22 | Est: 309.35s
- [Epoch 50/125] | Train Loss: 0.676| Train Accuracy: 81.79 | Val Loss: 1.012 | Val Accuracy: 72.45 | Est: 318.54s
- [Epoch 60/125] | Train Loss: 0.587| Train Accuracy: 86.15333333333334 | Val Loss: 0.994 | Val Accuracy: 72.98 | Est: 269.85s
- [Epoch 70/125] | Train Loss: 0.698| Train Accuracy: 82.6325 | Val Loss: 0.984 | Val Accuracy: 72.72 | Est: 300.34s
- [Epoch 80/125] | Train Loss: 0.710| Train Accuracy: 80.74666666666667 | Val Loss: 0.958 | Val Accuracy: 73.29 | Est: 296.03s
- [Epoch 90/125] | Train Loss: 0.632| Train Accuracy: 84.6225 | Val Loss: 0.949 | Val Accuracy: 73.57 | Est: 278.98s
- [Epoch 100/125] | Train Loss: 0.689| Train Accuracy: 82.01333333333334 | Val Loss: 0.955 | Val Accuracy: 73.4 | Est: 309.88s
- [Epoch 110/125] | Train Loss: 0.583| Train Accuracy: 85.4225 | Val Loss: 0.955 | Val Accuracy: 73.65 | Est: 299.26s
- [Epoch 120/125] | Train Loss: 0.645| Train Accuracy: 82.21833333333333 | Val Loss: 0.956 | Val Accuracy: 73.44 | Est: 289.41s
Highest Val Accuracy: 73.76 @ epoch 98 | Lowest Val Loss: 0.9380053132772446 @ epoch 115
_,_,_,_,_,df_augmentation2 = _train("SimpleNet-AutoAugment sample:5 (Coarse)", imageAutoAugment2,None, df_augmentation2,5,False,lr_=0.0125)
- [Epoch 10/125] | Train Loss: 1.393| Train Accuracy: 51.97125 | Val Loss: 1.234 | Val Accuracy: 61.34 | Est: 192.51s
- [Epoch 20/125] | Train Loss: 1.013| Train Accuracy: 63.81 | Val Loss: 1.070 | Val Accuracy: 67.01 | Est: 221.98s
- [Epoch 30/125] | Train Loss: 0.661| Train Accuracy: 74.85875 | Val Loss: 1.070 | Val Accuracy: 70.07 | Est: 240.78s
- [Epoch 40/125] | Train Loss: 0.431| Train Accuracy: 82.76375 | Val Loss: 1.086 | Val Accuracy: 70.81 | Est: 243.14s
- [Epoch 50/125] | Train Loss: 0.508| Train Accuracy: 80.76 | Val Loss: 1.053 | Val Accuracy: 71.56 | Est: 227.66s
- [Epoch 60/125] | Train Loss: 0.324| Train Accuracy: 87.61875 | Val Loss: 1.075 | Val Accuracy: 71.23 | Est: 225.57s
- [Epoch 70/125] | Train Loss: 0.282| Train Accuracy: 88.35875 | Val Loss: 1.093 | Val Accuracy: 71.67 | Est: 252.16s
- [Epoch 80/125] | Train Loss: 0.337| Train Accuracy: 86.88125 | Val Loss: 1.083 | Val Accuracy: 72.26 | Est: 197.70s
- [Epoch 90/125] | Train Loss: 0.280| Train Accuracy: 88.18875 | Val Loss: 1.097 | Val Accuracy: 72.2 | Est: 252.75s
EarlyStopper triggered at epochs: 96
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 72.69 @ epoch 78 | Lowest Val Loss: 0.9853408843278885 @ epoch 22
_,_,_,_,_,df_augmentation2 = _train("SimpleNet-ImageNet sample:5 (Coarse)", imageAutoAugmentImageNet2,None, df_augmentation2,5,False,lr_=0.015)
- [Epoch 10/125] | Train Loss: 1.410| Train Accuracy: 51.975 | Val Loss: 1.279 | Val Accuracy: 59.55 | Est: 282.78s
- [Epoch 20/125] | Train Loss: 0.946| Train Accuracy: 65.7675 | Val Loss: 1.089 | Val Accuracy: 66.43 | Est: 307.21s
- [Epoch 30/125] | Train Loss: 0.643| Train Accuracy: 75.1275 | Val Loss: 1.045 | Val Accuracy: 69.34 | Est: 243.33s
- [Epoch 40/125] | Train Loss: 0.507| Train Accuracy: 79.65625 | Val Loss: 1.077 | Val Accuracy: 69.6 | Est: 252.35s
- [Epoch 50/125] | Train Loss: 0.504| Train Accuracy: 80.3 | Val Loss: 1.072 | Val Accuracy: 70.28 | Est: 246.14s
- [Epoch 60/125] | Train Loss: 0.378| Train Accuracy: 84.555 | Val Loss: 1.186 | Val Accuracy: 69.91 | Est: 273.03s
- [Epoch 70/125] | Train Loss: 0.348| Train Accuracy: 85.265 | Val Loss: 1.198 | Val Accuracy: 70.49 | Est: 256.97s
- [Epoch 80/125] | Train Loss: 0.237| Train Accuracy: 89.8875 | Val Loss: 1.206 | Val Accuracy: 70.72 | Est: 289.31s
- [Epoch 90/125] | Train Loss: 0.287| Train Accuracy: 87.77875 | Val Loss: 1.243 | Val Accuracy: 70.71 | Est: 297.99s
EarlyStopper triggered at epochs: 99
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 71.22 @ epoch 81 | Lowest Val Loss: 1.0078613460063934 @ epoch 21
df_augmentation2
| | Lowest Val Loss | Highest Val Acc | Lowest Train Loss | Epoch (Highest Val Acc) | Model Description | Parameter | Highest Train Acc |
|---|---|---|---|---|---|---|---|
| 0 | 0.939657 | 74.06 | 0.378762 | 103 | SimpleNet-AutoAugment+CutMix sample:1 (Coarse) | 9,198,276 | 90.644167 |
| 0 | 0.938005 | 73.76 | 0.535824 | 98 | SimpleNet-ImageNet+CutMix sample:1 (Coarse) | 9,198,276 | 88.567500 |
| 0 | 0.985341 | 72.69 | 0.173510 | 78 | SimpleNet-AutoAugment sample:5 (Coarse) | 9,198,276 | 92.812500 |
| 0 | 1.007861 | 71.22 | 0.152900 | 81 | SimpleNet-ImageNet sample:5 (Coarse) | 9,198,276 | 92.907500 |
For fine labels: I will use AutoAugment (ImageNet augmentation policy) + CutMix with sample mode 1, where one third of the data is left untouched, another third is AutoAugmented with the ImageNet policy, and the last third is CutMixed. The reason is that this combination achieves a relatively high validation accuracy while having the lowest training accuracy, which shows that it overfits less than the other augmentations.
For coarse labels: I will use the same augmentation set and sample mode. AutoAugment (ImageNet augmentation policy) + CutMix seems to provide the least variance compared to the other augmentation sets.
TLDR: Sample mode 1 is chosen; the augmentation sets are chosen for least variance and high validation accuracy.
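Sample mode 1 keeps one third of the batch untouched, AutoAugments one third, and CutMixes the rest. As a reference for the CutMix third, here is a minimal NumPy sketch (a hypothetical `cutmix_pair` helper with a fixed-centre box for determinism; the real transform draws the box position at random, and this is not the notebook's implementation):

```python
import numpy as np

def cutmix_pair(img_a, img_b, label_a, label_b, lam, num_classes):
    """Paste a box from img_b into img_a; mix one-hot labels by pasted area."""
    h, w = img_a.shape[-2:]
    # box side lengths follow sqrt(1 - lam) so the pasted area is ~ (1 - lam)
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    y0 = (h - cut_h) // 2  # fixed-centre box; usually sampled at random
    x0 = (w - cut_w) // 2
    mixed = img_a.copy()
    mixed[..., y0:y0 + cut_h, x0:x0 + cut_w] = img_b[..., y0:y0 + cut_h, x0:x0 + cut_w]
    # label weight = fraction of img_a that survives the paste
    lam_adj = 1 - (cut_h * cut_w) / (h * w)
    target = np.zeros(num_classes)
    target[label_a] += lam_adj
    target[label_b] += 1 - lam_adj
    return mixed, target
```

With `lam=0.75` on a 4x4 image, a 2x2 box is pasted and the label mass splits 0.75/0.25 between the two classes.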
Current highest validation accuracy:
def _train(model_, model_desc, df, fine=True, lr_=0.01, epoch_=125):
    # fine=True trains on the 100 fine labels, fine=False on the 20 coarse labels;
    # both cases share the same SAM/SGD setup
    parameters_desc = "-"
    model = model_
    base_optimizer = optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, lr=lr_, momentum=0.90, weight_decay=0.00001)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    model = nn.DataParallel(model)
    model.to(device)
    labels = traindata_label if fine else trainlabel_coarse
    val_loader = valloader if fine else valloader_coarse
    class_no = 100 if fine else 20
    train_loss, train_acc, val_loss, val_acc, epoch_his = train_sam(model, traindata_input, labels, optimizer, epoch_, criterion, val_loader, earlyStopper, augCompose1_=imageAutoAugmentImageNet, augCompose2_=None, scheduler=scheduler, augmentEvery=1, displayEvery=10, sampleMode=1, classNo=class_no)
    df = keep_history(df, train_acc, train_loss, val_loss, val_acc, epoch_his, model_desc, parameters_desc)
    return train_loss, train_acc, val_loss, val_acc, epoch_his, df
df_complexity = pd.DataFrame([],columns=['Lowest Val Loss','Highest Val Acc','Lowest Train Loss','Epoch (Highest Val Acc)', 'Model Description', 'Parameter','Highest Train Acc'])
Using simpler and more complex variants of CoAtNet and SimpleNet to see whether they improve validation accuracy.
simpleCoAtNet = CoAtNet((32, 32), 3, [2, 2, 3, 5, 2], [32, 48, 128, 256, 512], num_classes=100)
simpleCoAtNet2 = CoAtNet((32, 32), 3, [2, 2, 3, 5, 2], [64, 96, 192, 384, 768], num_classes=100)
complexCoAtNet = CoAtNet((32, 32), 3, [2, 2, 6, 14, 2], [192, 192, 384, 768, 1536], num_classes=100)
simplerSimpleNet = SimpleNet([2,3,3,3],3,32,2,1024,20)
complexSimpleNet = SimpleNet([2,3,5,5],3,32,2,2048,20)
_,_,_,_,_,df_complexity = _train(simpleCoAtNet,'Simpler Custom CoAtNet',df_complexity,fine=True,lr_=0.05)
- [Epoch 10/125] | Train Loss: 2.012| Train Accuracy: 50.763333333333335 | Val Loss: 1.825 | Val Accuracy: 51.5 | Est: 615.15s
- [Epoch 20/125] | Train Loss: 1.399| Train Accuracy: 69.05666666666667 | Val Loss: 1.751 | Val Accuracy: 55.1 | Est: 625.75s
- [Epoch 30/125] | Train Loss: 1.135| Train Accuracy: 75.46416666666667 | Val Loss: 1.719 | Val Accuracy: 56.49 | Est: 616.01s
- [Epoch 40/125] | Train Loss: 1.088| Train Accuracy: 76.16166666666666 | Val Loss: 1.657 | Val Accuracy: 57.91 | Est: 631.53s
- [Epoch 50/125] | Train Loss: 1.087| Train Accuracy: 78.65166666666667 | Val Loss: 1.668 | Val Accuracy: 57.4 | Est: 643.72s
- [Epoch 60/125] | Train Loss: 1.091| Train Accuracy: 77.06333333333333 | Val Loss: 1.635 | Val Accuracy: 58.7 | Est: 642.08s
- [Epoch 70/125] | Train Loss: 1.006| Train Accuracy: 79.43833333333333 | Val Loss: 1.637 | Val Accuracy: 58.47 | Est: 617.36s
- [Epoch 80/125] | Train Loss: 0.942| Train Accuracy: 80.58666666666667 | Val Loss: 1.611 | Val Accuracy: 59.55 | Est: 653.44s
- [Epoch 90/125] | Train Loss: 0.973| Train Accuracy: 80.71333333333334 | Val Loss: 1.612 | Val Accuracy: 59.66 | Est: 644.79s
- [Epoch 100/125] | Train Loss: 0.935| Train Accuracy: 80.94666666666667 | Val Loss: 1.609 | Val Accuracy: 59.79 | Est: 590.29s
EarlyStopper triggered at epochs: 105
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 59.95 @ epoch 87 | Lowest Val Loss: 1.60446697473526 @ epoch 87
Training started around epoch
train_loss, train_acc, val_loss, val_acc, epoch_his, df_complexity = _train(simpleCoAtNet2,'Simpler Custom CoAtNet2',df_complexity,fine=True,lr_=0.05)
- [Epoch 10/125] | Train Loss: 1.627| Train Accuracy: 61.01166666666666 | Val Loss: 1.763 | Val Accuracy: 53.99 | Est: 590.90s
- [Epoch 20/125] | Train Loss: 1.260| Train Accuracy: 72.315 | Val Loss: 1.640 | Val Accuracy: 58.01 | Est: 591.54s
- [Epoch 30/125] | Train Loss: 1.117| Train Accuracy: 76.33416666666666 | Val Loss: 1.608 | Val Accuracy: 58.97 | Est: 568.95s
- [Epoch 40/125] | Train Loss: 0.885| Train Accuracy: 82.44 | Val Loss: 1.536 | Val Accuracy: 61.19 | Est: 603.43s
- [Epoch 50/125] | Train Loss: 0.936| Train Accuracy: 81.10916666666667 | Val Loss: 1.534 | Val Accuracy: 61.58 | Est: 586.24s
- [Epoch 60/125] | Train Loss: 0.977| Train Accuracy: 80.255 | Val Loss: 1.513 | Val Accuracy: 61.97 | Est: 595.57s
- [Epoch 70/125] | Train Loss: 0.884| Train Accuracy: 82.6775 | Val Loss: 1.506 | Val Accuracy: 63.0 | Est: 587.57s
- [Epoch 80/125] | Train Loss: 0.834| Train Accuracy: 84.0625 | Val Loss: 1.506 | Val Accuracy: 62.89 | Est: 584.46s
- [Epoch 90/125] | Train Loss: 0.730| Train Accuracy: 87.56166666666667 | Val Loss: 1.499 | Val Accuracy: 62.89 | Est: 609.02s
- [Epoch 100/125] | Train Loss: 0.755| Train Accuracy: 85.79 | Val Loss: 1.496 | Val Accuracy: 63.37 | Est: 580.43s
- [Epoch 110/125] | Train Loss: 0.829| Train Accuracy: 83.95916666666666 | Val Loss: 1.484 | Val Accuracy: 63.71 | Est: 593.83s
- [Epoch 120/125] | Train Loss: 0.676| Train Accuracy: 88.59833333333333 | Val Loss: 1.474 | Val Accuracy: 63.4 | Est: 588.89s
Highest Val Accuracy: 63.85 @ epoch 123 | Lowest Val Loss: 1.4706898152828216 @ epoch 86
For the more complex model, I use a larger batch size for faster training.
BATCH_SIZE = 1024
train_loss2, train_acc2, val_loss2, val_acc2, epoch_his2, df_complexity = _train(complexCoAtNet,'Complex Custom CoAtNet',df_complexity,fine=True,lr_=0.007,epoch_=75)
- [Epoch 10/75] | Train Loss: 1.887| Train Accuracy: 54.98583333333333 | Val Loss: 1.991 | Val Accuracy: 48.71 | Est: 1115.97s
- [Epoch 20/75] | Train Loss: 1.311| Train Accuracy: 69.46583333333334 | Val Loss: 1.840 | Val Accuracy: 52.42 | Est: 1127.45s
- [Epoch 30/75] | Train Loss: 1.234| Train Accuracy: 71.2475 | Val Loss: 1.764 | Val Accuracy: 54.49 | Est: 1139.04s
- [Epoch 40/75] | Train Loss: 1.207| Train Accuracy: 73.635 | Val Loss: 1.754 | Val Accuracy: 54.74 | Est: 1155.82s
- [Epoch 50/75] | Train Loss: 1.162| Train Accuracy: 73.49 | Val Loss: 1.711 | Val Accuracy: 55.57 | Est: 1103.61s
- [Epoch 60/75] | Train Loss: 1.167| Train Accuracy: 73.95583333333333 | Val Loss: 1.689 | Val Accuracy: 56.52 | Est: 1129.33s
- [Epoch 70/75] | Train Loss: 1.114| Train Accuracy: 74.01583333333333 | Val Loss: 1.678 | Val Accuracy: 56.7 | Est: 1085.94s
Highest Val Accuracy: 56.92 @ epoch 75 | Lowest Val Loss: 1.6766873955726624 @ epoch 73
df_complexity
| | Lowest Val Loss | Highest Val Acc | Lowest Train Loss | Epoch (Highest Val Acc) | Model Description | Parameter | Highest Train Acc |
|---|---|---|---|---|---|---|---|
| 0 | 1.604467 | 59.95 | 0.842774 | 87 | Simpler Custom CoAtNet | - | 84.537500 |
| 0 | 1.470690 | 63.85 | 0.607647 | 123 | Simpler Custom CoAtNet2 | - | 90.615833 |
| 0 | 1.676687 | 56.92 | 0.944950 | 75 | Complex Custom CoAtNet | - | 80.855000 |
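The Parameter column is left as "-" for these custom models. In PyTorch the count would be `sum(p.numel() for p in model.parameters())`; as a hand check, the standard per-layer formulas can be sketched as follows (these helpers are illustrative, not part of the notebook):

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Parameters in a Conv2d layer: c_out * c_in * k * k weights (+ c_out biases)."""
    return c_out * c_in * k * k + (c_out if bias else 0)

def linear_params(n_in, n_out, bias=True):
    """Parameters in a fully connected layer: n_out * n_in weights (+ n_out biases)."""
    return n_out * n_in + (n_out if bias else 0)

# e.g. a 3x3 conv mapping the 3-channel CIFAR input to 64 feature maps:
print(conv2d_params(3, 64, 3))   # 1792
# and a 100-way classifier head on a 512-dim feature vector:
print(linear_params(512, 100))   # 51300
```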
BATCH_SIZE = 512
df_complexity_coarse = pd.DataFrame([],columns=['Lowest Val Loss','Highest Val Acc','Lowest Train Loss','Epoch (Highest Val Acc)', 'Model Description', 'Parameter'])
train_loss3, train_acc3, val_loss3, val_acc3, epoch_his3, df_complexity_coarse = _train(simplerSimpleNet,'Simpler SimpleNet',df_complexity_coarse,fine=False,lr_=0.018)
- [Epoch 10/125] | Train Loss: 1.105| Train Accuracy: 68.5925 | Val Loss: 1.000 | Val Accuracy: 69.56 | Est: 314.65s
- [Epoch 20/125] | Train Loss: 0.859| Train Accuracy: 77.3025 | Val Loss: 0.973 | Val Accuracy: 71.91 | Est: 382.62s
- [Epoch 30/125] | Train Loss: 0.747| Train Accuracy: 81.32333333333334 | Val Loss: 0.966 | Val Accuracy: 73.18 | Est: 376.21s
- [Epoch 40/125] | Train Loss: 0.661| Train Accuracy: 83.67166666666667 | Val Loss: 0.962 | Val Accuracy: 73.2 | Est: 387.42s
- [Epoch 50/125] | Train Loss: 0.674| Train Accuracy: 83.3075 | Val Loss: 0.950 | Val Accuracy: 73.87 | Est: 372.15s
- [Epoch 60/125] | Train Loss: 0.690| Train Accuracy: 82.37583333333333 | Val Loss: 0.933 | Val Accuracy: 74.38 | Est: 382.38s
- [Epoch 70/125] | Train Loss: 0.644| Train Accuracy: 84.28583333333333 | Val Loss: 0.945 | Val Accuracy: 74.12 | Est: 380.73s
- [Epoch 80/125] | Train Loss: 0.682| Train Accuracy: 83.49916666666667 | Val Loss: 0.937 | Val Accuracy: 74.47 | Est: 366.12s
- [Epoch 90/125] | Train Loss: 0.636| Train Accuracy: 85.21083333333333 | Val Loss: 0.925 | Val Accuracy: 74.7 | Est: 342.49s
- [Epoch 100/125] | Train Loss: 0.687| Train Accuracy: 82.93416666666667 | Val Loss: 0.944 | Val Accuracy: 74.34 | Est: 371.32s
- [Epoch 110/125] | Train Loss: 0.640| Train Accuracy: 84.44166666666666 | Val Loss: 0.925 | Val Accuracy: 74.71 | Est: 337.54s
- [Epoch 120/125] | Train Loss: 0.605| Train Accuracy: 84.97833333333334 | Val Loss: 0.929 | Val Accuracy: 74.96 | Est: 359.46s
Highest Val Accuracy: 75.03 @ epoch 119 | Lowest Val Loss: 0.9137001007795333 @ epoch 115
train_loss4, train_acc4, val_loss4, val_acc4, epoch_his4, df_complexity_coarse = _train(complexSimpleNet,'Complex SimpleNet',df_complexity_coarse,fine=False,lr_=0.02,epoch_=75)
- [Epoch 10/75] | Train Loss: 1.409| Train Accuracy: 57.10333333333333 | Val Loss: 1.205 | Val Accuracy: 62.71 | Est: 443.74s
- [Epoch 20/75] | Train Loss: 0.983| Train Accuracy: 73.085 | Val Loss: 0.969 | Val Accuracy: 70.39 | Est: 410.56s
- [Epoch 30/75] | Train Loss: 0.919| Train Accuracy: 75.14916666666667 | Val Loss: 0.967 | Val Accuracy: 72.02 | Est: 392.56s
- [Epoch 40/75] | Train Loss: 0.750| Train Accuracy: 80.42333333333333 | Val Loss: 0.955 | Val Accuracy: 73.25 | Est: 461.82s
- [Epoch 50/75] | Train Loss: 0.684| Train Accuracy: 83.02416666666667 | Val Loss: 0.949 | Val Accuracy: 73.29 | Est: 439.97s
- [Epoch 60/75] | Train Loss: 0.684| Train Accuracy: 83.10083333333333 | Val Loss: 0.950 | Val Accuracy: 73.26 | Est: 400.40s
- [Epoch 70/75] | Train Loss: 0.648| Train Accuracy: 84.48 | Val Loss: 0.944 | Val Accuracy: 73.93 | Est: 446.15s
Highest Val Accuracy: 74.26 @ epoch 64 | Lowest Val Loss: 0.918514508008957 @ epoch 46
df_complexity_coarse
| | Lowest Val Loss | Highest Val Acc | Lowest Train Loss | Epoch (Highest Val Acc) | Model Description | Parameter | Highest Train Acc |
|---|---|---|---|---|---|---|---|
| 0 | 0.913700 | 75.03 | 0.536745 | 119 | Simpler SimpleNet | - | 87.925000 |
| 0 | 0.918515 | 74.26 | 0.592309 | 64 | Complex SimpleNet | - | 84.976667 |
def train_sam(model, label_ , optimizer, NUM_EPOCHS, criterion, valloader=valloader,earlystopper=None,showEpoch=True,scheduler=None,displayEvery=1,augBatch=1250,classNo=100):
val_loss_his = []
val_accuracy_his = []
epoch_his = []
earlystopper.reset()
model.train()
for epoch in range(NUM_EPOCHS):
t0 = time.time()
        input_rep, label_rep = replicateData1(traindata=traindata_input,augCompose=imageAutoAugmentImageNet,augBatch=augBatch,trainlabel=label_,classNum_=classNo)
loader = DataLoader(TensorDataset(input_rep.type('torch.FloatTensor'),label_rep.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
for data in loader:
img, label = data
img = img.to(device)
label = label.to(device)
outputs = model(img)
loss = criterion(outputs, label) # use this loss for any training statistics
loss.backward()
optimizer.first_step(zero_grad=True)
outputs = model(img)
# second forward-backward pass
criterion(outputs, label).backward() # make sure to do a full forward pass
optimizer.second_step(zero_grad=True)
scheduler.step()
val_loss, val_accuracy = validate(model, criterion, valloader)
val_loss_his.append(val_loss)
val_accuracy_his.append(val_accuracy)
epoch_his.append(epoch+1)
if showEpoch and (epoch+1)%displayEvery == 0:
print(f'- [Epoch {epoch+1}/{NUM_EPOCHS}] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: {val_accuracy} | Est: {(time.time() - t0)*displayEvery:.2f}s')
        if earlystopper(val_accuracy,val_loss):
            print(f'EarlyStopper triggered at epochs: {epoch+1} \n*No improvement to validation loss and accuracy could be seen for the past {earlystopper.patience} epochs')
            break
print(f'Highest Val Accuracy: {max(val_accuracy_his)} @ epoch {epoch_his[val_accuracy_his.index(max(val_accuracy_his))]} | Lowest Val Loss: {min(val_loss_his)} @ epoch {epoch_his[val_loss_his.index(min(val_loss_his))]}')
    return None, None, val_loss_his, val_accuracy_his, epoch_his  # train loss/acc are not tracked in this version
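`train_sam` performs SAM's two forward-backward passes: `first_step` perturbs the weights towards the locally worst-case (sharpest) direction, and `second_step` applies the base SGD update to the original weights using the gradient taken at the perturbed point. A 1-D stdlib illustration of that idea (not the actual `SAM` class used in this notebook):

```python
def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update on a scalar weight.

    1) ascend: probe w_adv = w + rho * sign(grad), the 1-D analogue of
       epsilon = rho * grad / |grad|
    2) descend: apply the SGD step to the ORIGINAL w using grad at w_adv
    """
    g = grad_fn(w)
    w_adv = w + rho * (1 if g > 0 else -1)
    g_adv = grad_fn(w_adv)
    return w - lr * g_adv

# quadratic loss L(w) = w^2, so grad(w) = 2w; iterating drives w towards 0
w = 1.0
for _ in range(50):
    w = sam_step(w, lambda x: 2 * x)
print(w)
```

The perturbed gradient pulls the update towards regions where the loss stays low in a neighbourhood, which is the flat-minima intuition behind SAM.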
model_checkpoint_fine = CoAtNet((32, 32), 3, [2, 2, 3, 5, 2], [64, 96, 192, 384, 768], num_classes=100)
opt_base = optim.SGD
opt = SAM(model_checkpoint_fine.parameters(), opt_base, lr=0.008, momentum=0.9, weight_decay=0.001)
lambda1 = lambda epoch: 0.979 ** epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda1)
model_checkpoint_fine = nn.DataParallel(model_checkpoint_fine)
model_checkpoint_fine.to(device)
# train_sam(model, label_ , optimizer, NUM_EPOCHS, criterion, valloader=valloader,earlystopper=None,showEpoch=True,scheduler=None,displayEvery=1,augBatch=1250):
_,_,_,_,_ = train_sam(model_checkpoint_fine,traindata_label,opt,200,criterion,valloader,earlyStopper,displayEvery=10,scheduler=scheduler,classNo=100)
- [Epoch 10/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 50.36 | Est: 723.76s
- [Epoch 20/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 56.48 | Est: 669.04s
- [Epoch 30/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 57.31 | Est: 687.76s
- [Epoch 40/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 59.34 | Est: 761.98s
- [Epoch 50/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 59.9 | Est: 747.61s
- [Epoch 60/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 60.64 | Est: 712.32s
- [Epoch 70/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 61.25 | Est: 670.57s
- [Epoch 80/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 61.35 | Est: 712.10s
- [Epoch 90/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 61.52 | Est: 731.11s
- [Epoch 100/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 62.0 | Est: 713.41s
- [Epoch 110/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 62.84 | Est: 771.24s
- [Epoch 120/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 62.75 | Est: 689.82s
- [Epoch 130/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.2 | Est: 736.62s
- [Epoch 140/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 62.94 | Est: 692.31s
- [Epoch 150/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.13 | Est: 655.27s
- [Epoch 160/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.55 | Est: 731.10s
- [Epoch 170/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.42 | Est: 740.34s
- [Epoch 180/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.26 | Est: 716.81s
- [Epoch 190/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.27 | Est: 760.95s
- [Epoch 200/200] | Train Loss: -| Train Accuracy: - | Val Loss: - | Val Accuracy: 63.65 | Est: 778.86s
Highest Val Accuracy: 63.94 @ epoch 182 | Lowest Val Loss: 1.412428170442581 @ epoch 198
torch.save(model_checkpoint_fine, 'save1')
def loadTunerState(stateData='hyperparmState.csv'):
    highest_val_acc = []
    val_acc_hist = []
    trial_hist = []
    with open(os.path.join(sys.path[0], stateData), 'r') as file:
        rawTxt = file.read()
    splited = rawTxt.split('\n')
    for val in splited[0].split(','):
        highest_val_acc.append(float(val))
    for val in splited[1].split(','):
        val_acc_hist.append(float(val))
    # Line 3 holds the flattened trial rows; regroup them 4 values at a time
    c = 0
    trial = []
    for val in splited[2].split(','):
        c += 1
        trial.append(float(val))
        if c % 4 == 0:
            trial_hist.append(trial)
            trial = []
    total_time = float(splited[3])
    saved_time = float(splited[4])
    return highest_val_acc, val_acc_hist, trial_hist, total_time, saved_time
def saveTunerState(highest_val_acc, val_acc_hist, trial_hist, stoppedTime, stateData='hyperparmState.csv'):
    with open(os.path.join(sys.path[0], stateData), 'w') as file:
        file.write(','.join(str(v) for v in highest_val_acc) + '\n')
        file.write(','.join(str(v) for v in val_acc_hist) + '\n')
        # Flatten the trial rows onto one line, then append elapsed time and save timestamp
        file.write(','.join(str(v) for row in trial_hist for v in row)
                   + '\n' + str(stoppedTime) + '\n' + str(time.time()))
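For reference, the one-list-per-line CSV layout that saveTunerState writes and loadTunerState reads back can be exercised in isolation. The save_state/load_state helpers below are simplified, hypothetical stand-ins for the functions above (same layout, fewer features), not the exact implementations:

```python
import os, tempfile, time

# Sketch of the tuner-state CSV layout: line 1 = best-accuracy record,
# line 2 = per-trial best val accuracies, line 3 = flattened trial rows,
# lines 4-5 = elapsed time and save timestamp.
def save_state(path, best, acc_hist, trials, elapsed):
    lines = [
        ",".join(str(v) for v in best),
        ",".join(str(v) for v in acc_hist),
        ",".join(str(v) for row in trials for v in row),
        str(elapsed),
        str(time.time()),
    ]
    with open(path, "w") as f:
        f.write("\n".join(lines))

def load_state(path, row_len=4):
    with open(path) as f:
        l1, l2, l3, l4, l5 = f.read().split("\n")
    best = [float(v) for v in l1.split(",")]
    acc_hist = [float(v) for v in l2.split(",")]
    flat = [float(v) for v in l3.split(",")]
    # Regroup the flattened values into rows of row_len (LR, WD, momentum, epoch)
    trials = [flat[i:i + row_len] for i in range(0, len(flat), row_len)]
    return best, acc_hist, trials, float(l4), float(l5)

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "state.csv")
    save_state(p, [69.5, 19, 46], [65.0, 69.5], [[1e-4, 6e-6, 0.92, 46]], 120.5)
    best, hist, trials, elapsed, saved = load_state(p)
```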
lambda1 = lambda epoch: 0.97 ** epoch
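Concretely, this lambda makes LambdaLR scale the optimizer's base learning rate by 0.97^epoch, i.e. exponential decay. A quick stand-alone check of the effective LR (using base_lr = 0.0125, the value used for the coarse fine-tuning run later):

```python
# lambda1 from the cell above: LambdaLR multiplies the optimizer's base LR
# by this factor, so lr(epoch) = base_lr * 0.97 ** epoch.
lambda1 = lambda epoch: 0.97 ** epoch

base_lr = 0.0125
effective = [base_lr * lambda1(e) for e in (0, 10, 50)]
# Epoch 0 keeps the base LR; by epoch 50 it has decayed to roughly 22% of it.
```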
def NetRandomTuner(hyperparameterCacheFile=None, saveAs='hyperparmState.csv', cacheEvery=1,
                   LR_range=np.logspace(1, 3, num=13)/100000 * 0.1,
                   WD_range=np.logspace(1, 3, num=10)/100000000,
                   Beta_range=np.linspace(0.88, 0.96, num=5),
                   trials=32, epoch=35, label_=traindata_label, valloader=valloader, classNum=100):
    global checkpoint_best
    possible_trials = []
    for LR, WD, B3 in itertools.product(LR_range, WD_range, Beta_range):
        possible_trials.append([LR, WD, B3])
    # Shuffle all possible trials
    random.shuffle(possible_trials)
    if hyperparameterCacheFile is None:
        highest_val_acc = [0, 0, 0]
        val_acc_hist = []
        trial_hist = []
        loaded_trial_hist = []
        trial_count = 0
        t0 = time.time()
    else:
        highest_val_acc, val_acc_hist, trial_hist, t0, time_saved = loadTunerState(hyperparameterCacheFile)
        # New t0 = old t0 shifted by the downtime (time of loading - time of saving)
        t0 += time.time() - time_saved
        trial_count = len(val_acc_hist)
        highest_val_acc[1] = int(highest_val_acc[1])  # ensure the trial index is an int
        highest_val_acc[2] = int(highest_val_acc[2])
        loaded_trial_hist = []
        for row in trial_hist:
            loaded_trial_hist.append(row[:3])
    opt_base = optim.SGD
    for trial in possible_trials:
        # If this trial was already run in a previous session, skip it and go to the next
        if trial in loaded_trial_hist:
            continue
        print('Next trial: ', trial)
        t1 = time.time()
        if trial_count == trials:
            print(f'\n\nTrial ended at trial #{trial_count}')
            break
        trial_count += 1
        # Many more options could be tuned (growth rate, filters, neurons); refer to section 4.2 discussion
        model = deepcopy(model_checkpoint_fine)
        model.to(device)
        opt = SAM(model.parameters(), opt_base, lr=trial[0], momentum=trial[2], weight_decay=trial[1])
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda1)
        _, _, _, val_acc, epoch_his = train_sam(model, label_, opt, epoch, criterion, valloader, earlyStopper,
                                                displayEvery=1, scheduler=scheduler, classNo=classNum)
        if max(val_acc) > highest_val_acc[0]:
            highest_val_acc = [max(val_acc), trial_count - 1, epoch_his[val_acc.index(max(val_acc))]]
            checkpoint_best = deepcopy(model)
        trial_hist.append([trial[0], trial[1], trial[2], epoch_his[val_acc.index(max(val_acc))]])
        val_acc_hist.extend([max(val_acc)])
        clear_output()
        print(f'''
Trial #{trial_count} Finished - Search Time {(time.time()-t1)/60:.2f} Mins
Total Time Elapsed: {(time.time()-t0)/60:.2f} Mins\n
Hyperparameters\t\t|Trial Values: #{trial_count}\t|Best Trial Values: #{highest_val_acc[1]+1}\n
Learning Rate\t\t|{trial[0]:.6f}\t\t|{trial_hist[highest_val_acc[1]][0]:.6f}
Weight Decay (L2)\t|{trial[1]:.6f}\t\t|{trial_hist[highest_val_acc[1]][1]:.6f}
Momentum\t\t|{trial[2]:.6f}\t\t|{trial_hist[highest_val_acc[1]][2]:.6f}
Highest Val Acc\t\t|{max(val_acc)}\t\t\t|{highest_val_acc[0]:.2f}
Epoch (Highest Val)\t|{epoch_his[val_acc.index(max(val_acc))]}\t\t\t|{highest_val_acc[2]}\n\n
''')
        if trial_count % cacheEvery == 0:
            saveTunerState(highest_val_acc, val_acc_hist, trial_hist, t0, stateData=saveAs)
    return trial_hist, val_acc_hist
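The search skeleton of NetRandomTuner (enumerate the full Cartesian product of the hyperparameter ranges, shuffle it, then skip any combination already present in the cached history) can be sketched without the model-training parts:

```python
import itertools, random

# Sketch of NetRandomTuner's random-search skeleton (no model training).
# The ranges below are tiny illustrative stand-ins for LR_range/WD_range/Beta_range.
lr_range = [1e-4, 3e-4]
wd_range = [1e-6, 1e-5]
beta_range = [0.9, 0.92]

# Enumerate the full grid, then shuffle so trials are visited in random order
grid = [list(t) for t in itertools.product(lr_range, wd_range, beta_range)]
random.seed(0)  # deterministic for this example only
random.shuffle(grid)

# Trials loaded from a previous session (e.g. from hyperparmState.csv) are skipped
already_ran = [[1e-4, 1e-6, 0.9]]
todo = [t for t in grid if t not in already_ran]
```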
trial_hist1, val_acc_hist1 = NetRandomTuner('hyperparmState.csv')
Trial #32 Finished - Search Time 27.56 Mins
Total Time Elapsed: 1002.02 Mins

Hyperparameters         |Trial Values: #32   |Best Trial Values: #20
Learning Rate           |0.000316            |0.000100
Weight Decay (L2)       |0.000000            |0.000006
Momentum                |0.920000            |0.920000
Highest Val Acc         |67.26               |69.52
Epoch (Highest Val)     |40                  |46

Next trial:  [0.00021544346900318823, 4.6415888336127773e-07, 0.9199999999999999]

Trial ended at trial #32
def hyperparmsOverview(val_acc_hist, trial_hist, returnNum):
    cols = ['Highest val acc', 'LR', 'Weight decay', 'Momentum', 'Epoch (of Highest Val Acc)']
    df_top = pd.DataFrame([], columns=cols)
    df_btm = pd.DataFrame([], columns=cols)
    top = sorted(zip(val_acc_hist, trial_hist), reverse=True)[:returnNum]
    btm = sorted(zip(val_acc_hist, trial_hist), reverse=False)[:returnNum]
    for i, e in zip(top, btm):
        df_top = pd.concat([df_top, pd.DataFrame([[i[0], i[1][0], i[1][1], i[1][2], i[1][3]]], columns=cols)])
        df_btm = pd.concat([df_btm, pd.DataFrame([[e[0], e[1][0], e[1][1], e[1][2], e[1][3]]], columns=cols)])
    print(f'Top {returnNum} Highest Val Accuracy Hyperparameters')
    display(df_top)
    print(f'\n\nBottom {returnNum} Highest Val Accuracy Hyperparameters')
    display(df_btm)
hyperparmsOverview(val_acc_hist1,trial_hist1,5)
Top 5 Highest Val Accuracy Hyperparameters
| | Highest val acc | LR | Weight decay | Momentum | Epoch (of Highest Val Acc) |
|---|---|---|---|---|---|
| 0 | 69.52 | 0.000100 | 0.000006 | 0.96 | 46.0 |
| 0 | 69.47 | 0.000147 | 0.000008 | 0.90 | 42.0 |
| 0 | 68.47 | 0.000015 | 0.000005 | 0.90 | 41.0 |
| 0 | 68.45 | 0.000032 | 0.000002 | 0.96 | 40.0 |
| 0 | 68.43 | 0.000016 | 0.000001 | 0.88 | 47.0 |
Bottom 5 Highest Val Accuracy Hyperparameters
| | Highest val acc | LR | Weight decay | Momentum | Epoch (of Highest Val Acc) |
|---|---|---|---|---|---|
| 0 | 67.23 | 0.000046 | 1.291550e-06 | 0.90 | 9.0 |
| 0 | 67.26 | 0.000316 | 1.668101e-07 | 0.92 | 40.0 |
| 0 | 67.27 | 0.000681 | 1.000000e-06 | 0.96 | 9.0 |
| 0 | 67.28 | 0.000681 | 5.994843e-06 | 0.94 | 9.0 |
| 0 | 67.29 | 0.000032 | 1.000000e-06 | 0.94 | 33.0 |
Analysis & Takeaway:
Generally, higher weight decay leads to better results: the top trials tend to use larger weight decays than the bottom trials. The best trial (#20) reached 69.52% validation accuracy. Furthermore, I set the epoch limit per trial to 50; since each trial starts from the 200-epoch checkpoint, the trials are effectively trained for roughly 240 epochs in total, so a higher epoch limit could be set in an attempt to achieve better results.
Discussion:
I have already done hyperparameter tuning on my fine-labelled model, and it has decent feature extraction capabilities on CIFAR-100. Is there any way I can 'transfer' those feature extraction capabilities and train it on my coarse-labelled data, especially since the training images are exactly the same? Perhaps I can apply transfer learning from my fine-labelled model, change the last output layer, and compare the results against my previous highest validation accuracy without transfer learning. Of course, transfer learning from a SOTA model is out of the question (many, if not all, SOTA models are trained with extra data or finetuned from weights that saw extra data).
Current highest validation accuracy:
# Copy our best model during hyperparameter tuning
finetune_coarse = deepcopy(checkpoint_best)
# Remove model from nn.DataParallel
finetune_coarse = finetune_coarse.module
finetune_coarse.fc
Linear(in_features=768, out_features=100, bias=False)
# Change last layer
finetune_coarse.fc = nn.Linear(768, 20, bias=False)
# model, input_, label_ , optimizer, NUM_EPOCHS, criterion, valloader=None,earlystopper=None,showEpoch=True,scheduler=None,displayEvery=1,augmentEvery=None,augCompose1_=imageAutoAugment,augCompose2_=imageAutoAugment,sampleMode=2,augBatch=1250,classNo=100
opt_base = optim.SGD
opt = SAM(finetune_coarse.parameters(), opt_base, lr=0.0125, momentum=0.9, weight_decay=0.00001)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda1)
finetune_coarse_save = deepcopy(finetune_coarse)
finetune_coarse = nn.DataParallel(finetune_coarse)
finetune_coarse.to(device)
_,_,_,_,_ = train_sam(finetune_coarse,traindata_input,trainlabel_coarse,opt,70,criterion,valloader_coarse,earlyStopper,scheduler=scheduler,augmentEvery=1,augCompose1_=imageAutoAugmentImageNet,sampleMode=1,classNo=20,displayEvery=5)
- [Epoch 5/70] | Train Loss: 0.505 | Train Accuracy: 88.33583333333333 | Val Loss: 0.857 | Val Accuracy: 75.79 | Est: 380.75s
- [Epoch 10/70] | Train Loss: 0.492 | Train Accuracy: 90.48 | Val Loss: 0.764 | Val Accuracy: 78.96 | Est: 371.58s
- [Epoch 15/70] | Train Loss: 0.475 | Train Accuracy: 90.96333333333334 | Val Loss: 0.729 | Val Accuracy: 79.09 | Est: 358.75s
- [Epoch 20/70] | Train Loss: 0.455 | Train Accuracy: 91.21166666666667 | Val Loss: 0.730 | Val Accuracy: 79.33 | Est: 357.70s
- [Epoch 25/70] | Train Loss: 0.464 | Train Accuracy: 90.43166666666667 | Val Loss: 0.722 | Val Accuracy: 79.41 | Est: 364.54s
- [Epoch 30/70] | Train Loss: 0.388 | Train Accuracy: 92.63166666666666 | Val Loss: 0.730 | Val Accuracy: 79.11 | Est: 358.24s
- [Epoch 35/70] | Train Loss: 0.435 | Train Accuracy: 91.93916666666667 | Val Loss: 0.741 | Val Accuracy: 79.01 | Est: 344.52s
- [Epoch 40/70] | Train Loss: 0.499 | Train Accuracy: 89.32333333333334 | Val Loss: 0.726 | Val Accuracy: 79.21 | Est: 365.33s
EarlyStopper triggered at epochs: 42
*No improvement to validation loss and accuracy could be seen for the past 18 epochs
Highest Val Accuracy: 79.57 @ epoch 24 | Lowest Val Loss: 0.7164515900611877 @ epoch 16
Transfer learning from the fine-labelled to the coarse-labelled model indeed yields an improvement: 75% -> 79.57%.
finetune_coarse_save = deepcopy(finetune_coarse)
Preparing test data & labels
testloader = DataLoader(TensorDataset(test_data.type('torch.FloatTensor'),test_label.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
testloader_coarse = DataLoader(TensorDataset(test_data.type('torch.FloatTensor'),test_label_coarse.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
def eval(model, criterion, test_loader, title, printHM=False, classlabel=class_labels):
    correct = 0
    total = 0
    print(title)
    y_pred = []
    y_true = []
    # Torch tensors have to be on one GPU device to concat, so GPU:0 is selected
    cuda0 = torch.device('cuda:0')
    wrong_samples = torch.empty(0, device=cuda0)
    wrong_preds = torch.empty(0, device=cuda0)
    actual_preds = torch.empty(0, device=cuda0)
    running_loss = []
    model.eval()
    with torch.no_grad():  # no gradients needed during evaluation; saves memory
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            y_pred.extend(outputs.argmax(-1).tolist())
            y_true.extend(labels.argmax(-1).tolist())
            correct += (outputs.argmax(-1) == labels.argmax(-1)).sum().item()
            running_loss.append(loss.item())
            total += labels.size(0)
            # Collect the misclassified samples for later visualization
            wrong_mask = outputs.argmax(-1) != labels.argmax(-1)
            wrong_samples = torch.cat((wrong_samples, inputs[wrong_mask]))
            wrong_preds = torch.cat((wrong_preds, outputs.argmax(-1)[wrong_mask]))
            actual_preds = torch.cat((actual_preds, labels.argmax(-1)[wrong_mask]))
    test_accuracy = 100 * correct / total
    test_loss = sum(running_loss) / len(running_loss)
    model.train()
    if printHM:
        fig, ax = plt.subplots(figsize=(13, 13))
        s = sns.heatmap(confusion_matrix(y_true, y_pred), cmap='Blues', annot=True, fmt='d', ax=ax,
                        xticklabels=classlabel, yticklabels=classlabel)
        s.set_xlabel('Predicted', fontsize=16)
        s.set_ylabel('Actual', fontsize=16)
        plt.show()
    return test_loss, test_accuracy, wrong_samples, wrong_preds, actual_preds, classification_report(y_true, y_pred, target_names=classlabel, output_dict=True)
def displayLowestF1(classification_dict, classes=100, returnLowest=10):
    f1_scores = []
    class_names = []
    count = 0
    # The report dict lists the per-class entries first, followed by aggregate rows
    # ('accuracy', 'macro avg', ...), so stop after `classes` entries
    for key, values in classification_dict.items():
        count += 1
        f1_scores.append(values['f1-score'])
        class_names.append(key)
        if count == classes:
            count = 0
            break
    for value, class_ in sorted(zip(f1_scores, class_names), reverse=False):
        count += 1
        print(class_, classification_dict[class_])
        if count == returnLowest:
            break
Final training for fine labelled data
criterion = nn.CrossEntropyLoss()
# It would be better to train from scratch, since we are now using more data (training + validation combined)
finalmodel_fine = deepcopy(model_checkpoint_fine)
opt_base = optim.SGD
if torch.cuda.device_count() > 1:
    finalmodel_fine = nn.DataParallel(finalmodel_fine)
finalmodel_fine.to(device)
optimizer = SAM(finalmodel_fine.parameters(), opt_base, lr=0.0001, momentum=0.92, weight_decay=0.000006)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
for epoch in range(46):
    input_rep, label_rep = replicateData1(traindata=train_data, augCompose=imageAutoAugmentImageNet, augBatch=1000, trainlabel=train_label, classNum_=100)
    loader = DataLoader(TensorDataset(input_rep.type('torch.FloatTensor'), label_rep.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
    for data in loader:
        img, label = data
        img = img.to(device)
        label = label.to(device)
        # First forward-backward pass
        outputs = finalmodel_fine(img)
        loss = criterion(outputs, label)  # use this loss for any training statistics
        loss.backward()
        optimizer.first_step(zero_grad=True)
        # Second forward-backward pass (SAM needs a full forward pass at the perturbed weights)
        outputs = finalmodel_fine(img)
        criterion(outputs, label).backward()
        optimizer.second_step(zero_grad=True)
    scheduler.step()
print('Finished Training.')
Finished Training.
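The two optimizer calls in the loop above are the heart of SAM. As a hedged illustration (assuming the SAM optimizer used here follows the usual sharpness-aware minimization recipe of Foret et al.: an ascent step to the perturbed weights w + rho * g / ||g||, then a descent step from the original w using the gradient taken at the perturbed point), here is the same two-step update on a toy 1-D quadratic:

```python
# Toy 1-D sketch of SAM's two-step update (not the actual SAM class used above).
def loss(w):
    return (w - 3.0) ** 2   # toy loss with its minimum at w = 3

def grad(w):
    return 2.0 * (w - 3.0)

w, lr, rho = 0.0, 0.1, 0.05
for _ in range(100):
    g = grad(w)
    eps = rho * g / (abs(g) + 1e-12)   # "first_step": climb toward the local worst case
    g_adv = grad(w + eps)              # gradient at the perturbed weights
    w = w - lr * g_adv                 # "second_step": descend from the original w
# w ends up hovering within about rho/10 of the minimum at 3.0
```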
Final training for coarse-labelled data:
criterion = nn.CrossEntropyLoss()
finalmodel_coarse = deepcopy(finalmodel_fine)
finalmodel_coarse = finalmodel_coarse.module
finalmodel_coarse.fc = nn.Linear(768, 20, bias=False)
opt_base = optim.SGD
if torch.cuda.device_count() > 1:
    finalmodel_coarse = nn.DataParallel(finalmodel_coarse)
finalmodel_coarse.to(device)
optimizer = SAM(finalmodel_coarse.parameters(), opt_base, lr=0.0125, momentum=0.9, weight_decay=0.00001)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
for epoch in range(25):
    input_rep, label_rep = replicateData1(traindata=train_data, augCompose=imageAutoAugmentImageNet, augBatch=1000, trainlabel=train_label_coarse, classNum_=20)
    loader = DataLoader(TensorDataset(input_rep.type('torch.FloatTensor'), label_rep.type('torch.FloatTensor')), shuffle=False, batch_size=BATCH_SIZE)
    for data in loader:
        img, label = data
        img = img.to(device)
        label = label.to(device)
        # First forward-backward pass
        outputs = finalmodel_coarse(img)
        loss = criterion(outputs, label)  # use this loss for any training statistics
        loss.backward()
        optimizer.first_step(zero_grad=True)
        # Second forward-backward pass (SAM needs a full forward pass at the perturbed weights)
        outputs = finalmodel_coarse(img)
        criterion(outputs, label).backward()
        optimizer.second_step(zero_grad=True)
    scheduler.step()
print('Finished Training Coarse Labelled Data.')
Finished Training Coarse Labelled Data.
test_loss, test_accuracy, wrong_samples, wrong_preds, actual_preds, class_dict = eval(finalmodel_fine,criterion,testloader,'Final Model')
print('\nFinal model (fine) test_accuracy (Top-1 accuracy): ',test_accuracy)
print('\nFinal model (fine) test_loss: ',test_loss)
Final Model

Final model (fine) test_accuracy (Top-1 accuracy):  69.64

Final model (fine) test_loss:  1.30165576338768
Displaying only the 10 lowest F1 scores:
displayLowestF1(class_dict)
otter {'precision': 0.3568181818181818, 'recall': 0.32, 'f1-score': 0.3372340425531915, 'support': 100}
lizard {'precision': 0.40082474226804123, 'recall': 0.39, 'f1-score': 0.3953299492385787, 'support': 100}
seal {'precision': 0.46025641025641024, 'recall': 0.36, 'f1-score': 0.39955056179775285, 'support': 100}
girl {'precision': 0.4657303370786517, 'recall': 0.41, 'f1-score': 0.4315343915343915, 'support': 100}
man {'precision': 0.4575824175824176, 'recall': 0.42, 'f1-score': 0.4379057591623037, 'support': 100}
woman {'precision': 0.44, 'recall': 0.44, 'f1-score': 0.4400000000000001, 'support': 100}
shrew {'precision': 0.43622641509433965, 'recall': 0.46, 'f1-score': 0.4477669902912621, 'support': 100}
mouse {'precision': 0.4456603773584906, 'recall': 0.47, 'f1-score': 0.45747572815533984, 'support': 100}
boy {'precision': 0.4670833333333333, 'recall': 0.45, 'f1-score': 0.4583673469387755, 'support': 100}
bear {'precision': 0.50987951807228917, 'recall': 0.43, 'f1-score': 0.46622950819672134, 'support': 100}
test_loss2, test_accuracy2, wrong_samples2, wrong_preds2, actual_preds2, class_dict2 = eval(finetune_coarse,criterion,testloader_coarse,'Final Model',printHM=True,classlabel=superclass_labels)
print('\nFinal model (coarse) test_accuracy (Top-1 accuracy): ',test_accuracy2)
print('\nFinal model (coarse) test_loss: ',test_loss2)
Final Model
Final model (coarse) test_accuracy (Top-1 accuracy):  79.65

Final model (coarse) test_loss:  0.7096987730026245
displayLowestF1(class_dict2,classes=20)
b'reptiles' {'precision': 0.644887983706721, 'recall': 0.634, 'f1-score': 0.6393945509586277, 'support': 500}
b'small_mammals' {'precision': 0.6481226053639846, 'recall': 0.674, 'f1-score': 0.6607827788649706, 'support': 500}
b'aquatic_mammals' {'precision': 0.671980198019802, 'recall': 0.678, 'f1-score': 0.6749751243781095, 'support': 500}
b'medium_mammals' {'precision': 0.7826696832579186, 'recall': 0.71, 'f1-score': 0.7399898089171974, 'support': 500}
b'large_carnivores' {'precision': 0.7027586206896552, 'recall': 0.804, 'f1-score': 0.7496296296296296, 'support': 500}
b'non-insect_invertebrates' {'precision': 0.8040909090909091, 'recall': 0.716, 'f1-score': 0.7572340425531914, 'support': 500}
b'large_omnivores_and_herbivores' {'precision': 0.7487072243346007, 'recall': 0.784, 'f1-score': 0.7659064327485379, 'support': 500}
b'household_electrical_devices' {'precision': 0.8704866180048662, 'recall': 0.728, 'f1-score': 0.7922832052689353, 'support': 500}
b'fish' {'precision': 0.8202061855670103, 'recall': 0.798, 'f1-score': 0.8089340101522843, 'support': 500}
b'insects' {'precision': 0.8826905829596412, 'recall': 0.796, 'f1-score': 0.8368710359408033, 'support': 500}
Fine-label wrongly predicted samples
random_idxs = np.random.choice(wrong_preds.shape[0], 50, replace=False)
fig, ax = plt.subplots(10, 5, figsize=(22, 36))
plt.axis("off")
for idx, subplot in zip(random_idxs, ax.ravel()):
    pred = class_labels[wrong_preds.type('torch.LongTensor')[idx]]
    actual = class_labels[actual_preds.type('torch.LongTensor')[idx]]
    subplot.imshow(img4np(wrong_samples.cpu())[idx])
    subplot.set_title(f"Actual: {actual}\nPredicted: {pred}", fontsize=16)
plt.tight_layout()
Coarse-label wrongly predicted samples
random_idxs = np.random.choice(wrong_preds2.shape[0], 25, replace=False)
fig, ax = plt.subplots(4, 5, figsize=(42, 24))
plt.axis("off")
for idx, subplot in zip(random_idxs, ax.ravel()):
    pred = superclass_labels[wrong_preds2.type('torch.LongTensor')[idx]]
    actual = superclass_labels[actual_preds2.type('torch.LongTensor')[idx]]
    subplot.imshow(img4np(wrong_samples2.cpu())[idx])
    subplot.set_title(f"Actual: {actual}\nPredicted: {pred}", fontsize=16)
# plt.tight_layout()
Fine-label feature mapping: only the fine model's feature maps are shown, as they will be almost identical to the coarse model's, since the coarse model was transfer-learned from the fine model during model improvement.
layers = []
for parm in finalmodel_fine.parameters():
    layers.append(parm)
The first layer's weight has shape [64 output channels, 3 input channels, 3 kernel height, 3 kernel width]
layers[0].shape
torch.Size([64, 3, 3, 3])
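As a sanity check, the parameter count implied by this shape is just the product of its dimensions (this layer has bias=False, so there is no bias term):

```python
# Parameter count implied by the [out_channels, in_channels, kH, kW] shape above
out_ch, in_ch, kh, kw = 64, 3, 3, 3
n_params = out_ch * in_ch * kh * kw   # 64 filters x 27 weights each = 1728
```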
from torchvision import utils

def visTensor(tensor, ch=0, allkernels=False, nrow=8, padding=1):
    n, c, w, h = tensor.shape
    if allkernels:
        tensor = tensor.view(n * c, -1, w, h)
    elif c != 3:
        tensor = tensor[:, ch, :, :].unsqueeze(dim=1)
    rows = np.min((tensor.shape[0] // nrow + 1, 64))
    grid = utils.make_grid(tensor, nrow=nrow, normalize=True, padding=padding)
    plt.figure(figsize=(nrow, rows))
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
filter = layers[0].data.cpu().clone()
visTensor(filter, ch=0, allkernels=False)
plt.axis('off')
plt.ioff()
plt.show()
Unlike in Part A, the weights of these filters are far harder to comprehend, as there are now 3 color channels with very differing shades to display.
def featureMap(conv, image_input, plotX=8, plotY=8):
    cuda0 = torch.device('cuda:0')
    # Run the image through the given layer and take the first sample's feature maps
    image_map = conv(image_input.to(cuda0))
    image_map = image_map.cpu().detach().numpy()[0]
    fig, ax = plt.subplots(plotY, plotX, figsize=(20, 20))
    for i, subplot in enumerate(ax.ravel()):
        subplot.imshow(image_map[i], cmap='gray')
        subplot.set_title(f"Filter {i+1}")
        subplot.axis("off")
    plt.show()
The first layer of my final model
finalmodel_fine.module.s0[0][0]
Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
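With kernel 3, stride 2, and padding 1, this first convolution halves the spatial resolution of the 32x32 CIFAR-100 input; the standard convolution output-size formula confirms it:

```python
# Standard conv output-size formula, applied to the Conv2d above
def conv_out(size, kernel=3, stride=2, padding=1):
    return (size + 2 * padding - kernel) // stride + 1

h = conv_out(32)   # 32 -> 16: the feature maps shown below are 16x16
```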
First layer of our model, with activation function
finalmodel_fine.module.s0[0]
Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') )
Original
plt.imshow(traindata_input[0].numpy()[0],cmap='Reds')
plt.show()
plt.imshow(traindata_input[0].numpy()[1],cmap='Greens')
plt.show()
plt.imshow(traindata_input[0].numpy()[2],cmap='Blues')
plt.show()
plt.imshow(img4np(torch.unsqueeze(traindata_input[0], dim=0))[0])
plt.show()
print('Feature map of cattle (Layer1 before activation function)')
featureMap(finalmodel_fine.module.s0[0][0],torch.unsqueeze(traindata_input[0], dim=0),8,8)
print('Feature map of cattle (Layer1 after activation function)')
featureMap(finalmodel_fine.module.s0[0],torch.unsqueeze(traindata_input[0], dim=0),8,8)
Feature map of cattle (Layer1 before activation function)
Feature map of cattle (Layer1 after activation function)
torch.save(finalmodel_fine,'finalmodel_fine')
torch.save(finalmodel_coarse,'finalmodel_coarse')
You have reached the end of Part B :))